From 548af3c2186b678fa3860104d1457affc4feae8f Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Wed, 26 Jun 2024 15:25:41 +0200 Subject: [PATCH 1/2] add optional support for simsimd, update usearch --- CMakeLists.txt | 8 + src/hnsw/hnsw_index_physical_create.cpp | 3 +- src/include/simsimd/LICENSE | 201 + src/include/simsimd/binary.h | 258 + src/include/simsimd/dot.h | 1671 ++++++ src/include/simsimd/geospatial.h | 37 + src/include/simsimd/probability.h | 573 ++ src/include/simsimd/simsimd.h | 1261 +++++ src/include/simsimd/spatial.h | 1341 +++++ src/include/simsimd/types.h | 423 ++ src/include/usearch/duckdb_usearch.hpp | 8 +- src/include/usearch/index.hpp | 6777 ++++++++++++----------- src/include/usearch/index_dense.hpp | 3478 ++++++------ src/include/usearch/index_plugins.hpp | 3352 +++++------ 14 files changed, 12659 insertions(+), 6732 deletions(-) create mode 100644 src/include/simsimd/LICENSE create mode 100644 src/include/simsimd/binary.h create mode 100644 src/include/simsimd/dot.h create mode 100644 src/include/simsimd/geospatial.h create mode 100644 src/include/simsimd/probability.h create mode 100644 src/include/simsimd/simsimd.h create mode 100644 src/include/simsimd/spatial.h create mode 100644 src/include/simsimd/types.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 6b8aadd..6edbf7e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,6 +7,14 @@ set(EXTENSION_NAME ${TARGET_NAME}_extension) set(LOADABLE_EXTENSION_NAME ${TARGET_NAME}_loadable_extension) project(${TARGET_NAME}) + +option(USE_SIMSIMD "Use SIMSIMD library to sacrifice portability for vectorized search" OFF) +if(USE_SIMSIMD) + add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=1) +else() + add_definitions(-DDUCKDB_USEARCH_USE_SIMSIMD=0) +endif() + include_directories(src/include) set(EXTENSION_SOURCES src/vss_extension.cpp) diff --git a/src/hnsw/hnsw_index_physical_create.cpp b/src/hnsw/hnsw_index_physical_create.cpp index e3d6dda..fa771c0 100644 --- a/src/hnsw/hnsw_index_physical_create.cpp +++ b/src/hnsw/hnsw_index_physical_create.cpp @@ -287,8 +287,9 @@ SinkFinalizeType PhysicalCreateHNSWIndex::Finalize(Pipeline &pipeline, Event &ev gstate.is_building = true; // Reserve the index size + auto &ts = TaskScheduler::GetScheduler(context); auto &index = gstate.global_index->index; - index.reserve(collection->Count()); + index.reserve({static_cast(collection->Count()), static_cast(ts.NumberOfThreads())}); // Initialize a parallel scan for the index construction collection->InitializeScan(gstate.scan_state, ColumnDataScanProperties::ALLOW_ZERO_COPY); diff --git a/src/include/simsimd/LICENSE b/src/include/simsimd/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/src/include/simsimd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
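Note on the two functional changes at the top of this patch; the remaining hunks below vendor the SimSIMD headers and the regenerated usearch sources. The new USE_SIMSIMD CMake option only flips the DUCKDB_USEARCH_USE_SIMSIMD define (presumably read by the updated duckdb_usearch.hpp shim), so the vectorized backend is chosen at configure time, e.g. with -DUSE_SIMSIMD=ON, without further source changes. The other change hands the HNSW index a thread count alongside the row count when reserving capacity. A minimal sketch of what that reservation expresses, with assumed names standing in for usearch's actual limits type:

// Hedged sketch, not part of the patch: what the two-argument reserve() call in
// PhysicalCreateHNSWIndex::Finalize() above expresses. limits_sketch_t is an assumed
// stand-in for usearch's limits type; only the {row count, thread count} shape comes
// from the patch itself.
#include <cstddef>
#include <thread>

struct limits_sketch_t {
    std::size_t members; // rows the index should make room for, i.e. collection->Count()
    std::size_t threads; // tasks that will insert concurrently, i.e. ts.NumberOfThreads()
};

int main() {
    // Reserving both fields up front lets the index size its storage for `members`
    // vectors and keep one insertion context per worker thread, so the parallel
    // index build does not have to grow shared state while tasks are adding rows.
    limits_sketch_t limits{std::size_t{1000000}, std::thread::hardware_concurrency()};
    (void)limits;
    return 0;
}

The intent, as far as this hunk shows, is that Finalize() can then drive NumberOfThreads() concurrent insertion tasks over the column-data scan against pre-sized per-thread state.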
diff --git a/src/include/simsimd/binary.h b/src/include/simsimd/binary.h new file mode 100644 index 0000000..b664a21 --- /dev/null +++ b/src/include/simsimd/binary.h @@ -0,0 +1,258 @@ +/** + * @file binary.h + * @brief SIMD-accelerated Binary Similarity Measures. + * @author Ash Vardanian + * @date July 1, 2023 + * + * Contains: + * - Hamming distance + * - Jaccard similarity (Tanimoto coefficient) + * + * For hardware architectures: + * - Arm (NEON, SVE) + * - x86 (AVX2, AVX512) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + */ +#ifndef SIMSIMD_BINARY_H +#define SIMSIMD_BINARY_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// clang-format off + +/* Serial backends for bitsets. */ +SIMSIMD_PUBLIC void simsimd_hamming_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +SIMSIMD_PUBLIC void simsimd_jaccard_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); + +/* Arm NEON backend for bitsets. */ +SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); + +/* Arm SVE backend for bitsets. */ +SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); + +/* x86 AVX2 backend for bitsets for Intel Haswell CPUs and newer, needs only POPCNT extensions. */ +SIMSIMD_PUBLIC void simsimd_hamming_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +SIMSIMD_PUBLIC void simsimd_jaccard_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); + +/* x86 AVX512 backend for bitsets for Intel Ice Lake CPUs and newer, using VPOPCNTDQ extensions. 
*/ +SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, simsimd_distance_t* distance); +// clang-format on + +SIMSIMD_PUBLIC unsigned char simsimd_popcount_b8(simsimd_b8_t x) { + static unsigned char lookup_table[] = { + 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, // + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, + 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7, 4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8}; + return lookup_table[x]; +} + +SIMSIMD_PUBLIC void simsimd_hamming_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_i32_t differences = 0; + for (simsimd_size_t i = 0; i != n_words; ++i) + differences += simsimd_popcount_b8(a[i] ^ b[i]); + *result = differences; +} + +SIMSIMD_PUBLIC void simsimd_jaccard_b8_serial(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_i32_t intersection = 0, union_ = 0; + for (simsimd_size_t i = 0; i != n_words; ++i) + intersection += simsimd_popcount_b8(a[i] & b[i]), union_ += simsimd_popcount_b8(a[i] | b[i]); + *result = (union_ != 0) ? 1 - (simsimd_f32_t)intersection / (simsimd_f32_t)union_ : 0; +} + +#if SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC target("+simd") +#pragma clang attribute push(__attribute__((target("+simd"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_hamming_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_i32_t differences = 0; + simsimd_size_t i = 0; + for (; i + 16 <= n_words; i += 16) { + uint8x16_t a_first = vld1q_u8(a + i); + uint8x16_t b_first = vld1q_u8(b + i); + differences += vaddvq_u8(vcntq_u8(veorq_u8(a_first, b_first))); + } + // Handle the tail + for (; i != n_words; ++i) + differences += simsimd_popcount_b8(a[i] ^ b[i]); + *result = differences; +} + +SIMSIMD_PUBLIC void simsimd_jaccard_b8_neon(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_i32_t intersection = 0, union_ = 0; + simsimd_size_t i = 0; + for (; i + 16 <= n_words; i += 16) { + uint8x16_t a_first = vld1q_u8(a + i); + uint8x16_t b_first = vld1q_u8(b + i); + intersection += vaddvq_u8(vcntq_u8(vandq_u8(a_first, b_first))); + union_ += vaddvq_u8(vcntq_u8(vorrq_u8(a_first, b_first))); + } + // Handle the tail + for (; i != n_words; ++i) + intersection += simsimd_popcount_b8(a[i] & b[i]), union_ += simsimd_popcount_b8(a[i] | b[i]); + *result = (union_ != 0) ? 
1 - (simsimd_f32_t)intersection / (simsimd_f32_t)union_ : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON + +#if SIMSIMD_TARGET_SVE +#pragma GCC target("+sve") +#pragma clang attribute push(__attribute__((target("+sve"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_hamming_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + simsimd_i32_t differences = 0; + do { + svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); + svuint8_t a_vec = svld1_u8(pg_vec, a + i); + svuint8_t b_vec = svld1_u8(pg_vec, b + i); + differences += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), sveor_u8_m(svptrue_b8(), a_vec, b_vec))); + i += svcntb(); + } while (i < n_words); + *result = differences; +} + +SIMSIMD_PUBLIC void simsimd_jaccard_b8_sve(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + simsimd_i32_t intersection = 0, union_ = 0; + do { + svbool_t pg_vec = svwhilelt_b8((unsigned int)i, (unsigned int)n_words); + svuint8_t a_vec = svld1_u8(pg_vec, a + i); + svuint8_t b_vec = svld1_u8(pg_vec, b + i); + intersection += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), svand_u8_m(svptrue_b8(), a_vec, b_vec))); + union_ += svaddv_u8(svptrue_b8(), svcnt_u8_x(svptrue_b8(), svorr_u8_m(svptrue_b8(), a_vec, b_vec))); + i += svcntb(); + } while (i < n_words); + *result = (union_ != 0) ? 1 - (simsimd_f32_t)intersection / (simsimd_f32_t)union_ : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SVE +#endif // SIMSIMD_TARGET_ARM + +#if SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_ICE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vpopcntdq") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512vpopcntdq"))), \ + apply_to = function) + +SIMSIMD_PUBLIC void simsimd_hamming_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + __m512i differences_vec = _mm512_setzero_si512(); + __m512i a_vec, b_vec; + +simsimd_hamming_b8_ice_cycle: + if (n_words < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words); + a_vec = _mm512_maskz_loadu_epi8(mask, a); + b_vec = _mm512_maskz_loadu_epi8(mask, b); + n_words = 0; + } else { + a_vec = _mm512_loadu_epi8(a); + b_vec = _mm512_loadu_epi8(b); + a += 64, b += 64, n_words -= 64; + } + __m512i xor_vec = _mm512_xor_si512(a_vec, b_vec); + differences_vec = _mm512_add_epi64(differences_vec, _mm512_popcnt_epi64(xor_vec)); + if (n_words) + goto simsimd_hamming_b8_ice_cycle; + + simsimd_size_t differences = _mm512_reduce_add_epi64(differences_vec); + *result = differences; +} + +SIMSIMD_PUBLIC void simsimd_jaccard_b8_ice(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + __m512i intersection_vec = _mm512_setzero_si512(), union_vec = _mm512_setzero_si512(); + __m512i a_vec, b_vec; + +simsimd_jaccard_b8_ice_cycle: + if (n_words < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n_words); + a_vec = _mm512_maskz_loadu_epi8(mask, a); + b_vec = _mm512_maskz_loadu_epi8(mask, b); + n_words = 0; + } else { + a_vec = _mm512_loadu_epi8(a); + b_vec = _mm512_loadu_epi8(b); + a += 64, b += 64, n_words -= 64; + } + __m512i and_vec = _mm512_and_si512(a_vec, b_vec); + __m512i or_vec = _mm512_or_si512(a_vec, b_vec); + 
intersection_vec = _mm512_add_epi64(intersection_vec, _mm512_popcnt_epi64(and_vec)); + union_vec = _mm512_add_epi64(union_vec, _mm512_popcnt_epi64(or_vec)); + if (n_words) + goto simsimd_jaccard_b8_ice_cycle; + + simsimd_size_t intersection = _mm512_reduce_add_epi64(intersection_vec), + union_ = _mm512_reduce_add_epi64(union_vec); + *result = (union_ != 0) ? 1 - (simsimd_f32_t)intersection / (simsimd_f32_t)union_ : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_ICE + +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("popcnt") +#pragma clang attribute push(__attribute__((target("popcnt"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_hamming_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + // x86 supports unaligned loads and works just fine with the scalar version for small vectors. + simsimd_size_t differences = 0; + for (; n_words >= 8; n_words -= 8, a += 8, b += 8) + differences += _mm_popcnt_u64(*(simsimd_u64_t const*)a ^ *(simsimd_u64_t const*)b); + for (; n_words; --n_words, ++a, ++b) + differences += _mm_popcnt_u32(*a ^ *b); + *result = differences; +} + +SIMSIMD_PUBLIC void simsimd_jaccard_b8_haswell(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n_words, + simsimd_distance_t* result) { + // x86 supports unaligned loads and works just fine with the scalar version for small vectors. + simsimd_size_t intersection = 0, union_ = 0; + for (; n_words >= 8; n_words -= 8, a += 8, b += 8) + intersection += _mm_popcnt_u64(*(simsimd_u64_t const*)a & *(simsimd_u64_t const*)b), + union_ += _mm_popcnt_u64(*(simsimd_u64_t const*)a | *(simsimd_u64_t const*)b); + for (; n_words; --n_words, ++a, ++b) + intersection += _mm_popcnt_u32(*a & *b), union_ += _mm_popcnt_u32(*a | *b); + *result = (union_ != 0) ? 1 - (simsimd_f32_t)intersection / (simsimd_f32_t)union_ : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL +#endif // SIMSIMD_TARGET_X86 + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/include/simsimd/dot.h b/src/include/simsimd/dot.h new file mode 100644 index 0000000..6d33d16 --- /dev/null +++ b/src/include/simsimd/dot.h @@ -0,0 +1,1671 @@ +/** + * @file dot.h + * @brief SIMD-accelerated Dot Products for Real and Complex numbers. + * @author Ash Vardanian + * @date February 24, 2024 + * + * Contains: + * - Dot Product for Real and Complex vectors + * - Conjugate Dot Product for Complex vectors + * + * For datatypes: + * - 64-bit IEEE floating point numbers + * - 32-bit IEEE floating point numbers + * - 16-bit IEEE floating point numbers + * - 16-bit brain floating point numbers + * - 8-bit signed integers + * + * For hardware architectures: + * - Arm (NEON, SVE?) + * - x86 (AVX2?, AVX512?) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + */ +#ifndef SIMSIMD_DOT_H +#define SIMSIMD_DOT_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// clang-format off + +/* Serial backends for all numeric types. + * By default they use 32-bit arithmetic, unless the arguments themselves contain 64-bit floats. + * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. 
+ */ +SIMSIMD_PUBLIC void simsimd_dot_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f64c_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_dot_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +/* Double-precision serial backends for all numeric types. + * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. + */ +SIMSIMD_PUBLIC void simsimd_dot_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. + * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all + * server CPUs produced before 2023. 
+ */ +SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +/* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes. + * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs. + */ +SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words. + * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420. + * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms. + * On other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are + * properly vectorized by recent compilers. 
+ */ +SIMSIMD_PUBLIC void simsimd_dot_f32_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* results); + +SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +/* SIMD-powered backends for various generations of AVX512 CPUs. + * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop. + * Ice Lake added VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, and other extensions for integral operations. + * Genoa added only BF16. + * Sapphire Rapids added tiled matrix operations, but we are most interested in the new mixed-precision FMA instructions. + */ +SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t*); + +SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* result); + +SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* result); +// clang-format on + +#define SIMSIMD_MAKE_DOT(name, input_type, accumulator_type, converter) \ + SIMSIMD_PUBLIC 
void simsimd_dot_##input_type##_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* result) { \ + simsimd_##accumulator_type##_t ab = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = converter(a[i]); \ + simsimd_##accumulator_type##_t bi = converter(b[i]); \ + ab += ai * bi; \ + } \ + *result = ab; \ + } + +#define SIMSIMD_MAKE_COMPLEX_DOT(name, input_type, accumulator_type, converter) \ + SIMSIMD_PUBLIC void simsimd_dot_##input_type##c_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ + simsimd_##accumulator_type##_t ar = converter(a[i]); \ + simsimd_##accumulator_type##_t br = converter(b[i]); \ + simsimd_##accumulator_type##_t ai = converter(a[i + 1]); \ + simsimd_##accumulator_type##_t bi = converter(b[i + 1]); \ + ab_real += ar * br - ai * bi; \ + ab_imag += ar * bi + ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ + } + +#define SIMSIMD_MAKE_COMPLEX_VDOT(name, input_type, accumulator_type, converter) \ + SIMSIMD_PUBLIC void simsimd_vdot_##input_type##c_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* results) { \ + simsimd_##accumulator_type##_t ab_real = 0, ab_imag = 0; \ + for (simsimd_size_t i = 0; i + 2 <= n; i += 2) { \ + simsimd_##accumulator_type##_t ar = converter(a[i]); \ + simsimd_##accumulator_type##_t br = converter(b[i]); \ + simsimd_##accumulator_type##_t ai = converter(a[i + 1]); \ + simsimd_##accumulator_type##_t bi = converter(b[i + 1]); \ + ab_real += ar * br + ai * bi; \ + ab_imag += ar * bi - ai * br; \ + } \ + results[0] = ab_real; \ + results[1] = ab_imag; \ + } + +SIMSIMD_MAKE_DOT(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_dot_f64_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_dot_f64c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_vdot_f64c_serial + +SIMSIMD_MAKE_DOT(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_dot_f32_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_dot_f32c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_vdot_f32c_serial + +SIMSIMD_MAKE_DOT(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_dot_f16_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_dot_f16c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_vdot_f16c_serial + +SIMSIMD_MAKE_DOT(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16) // simsimd_dot_bf16_serial +SIMSIMD_MAKE_COMPLEX_DOT(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16) // simsimd_dot_bf16c_serial +SIMSIMD_MAKE_COMPLEX_VDOT(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16) // simsimd_vdot_bf16c_serial + +SIMSIMD_MAKE_DOT(serial, i8, i64, SIMSIMD_IDENTIFY) // simsimd_dot_i8_serial + +SIMSIMD_MAKE_DOT(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_dot_f32_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_dot_f32c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_vdot_f32c_accurate + +SIMSIMD_MAKE_DOT(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // simsimd_dot_f16_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // 
simsimd_dot_f16c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // simsimd_vdot_f16c_accurate + +SIMSIMD_MAKE_DOT(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16) // simsimd_dot_bf16_accurate +SIMSIMD_MAKE_COMPLEX_DOT(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16) // simsimd_dot_bf16c_accurate +SIMSIMD_MAKE_COMPLEX_VDOT(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16) // simsimd_vdot_bf16c_accurate + +#if SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC push_options +#pragma GCC target("+simd") +#pragma clang attribute push(__attribute__((target("+simd"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); + } + simsimd_f32_t ab = vaddvq_f32(ab_vec); + for (; i < n; ++i) + ab += a[i] * b[i]; + *result = ab; +} + +SIMSIMD_PUBLIC void simsimd_dot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts: + float32x4x2_t a_vec = vld2q_f32(a + i); + float32x4x2_t b_vec = vld2q_f32(b + i); + float32x4_t a_real_vec = a_vec.val[0]; + float32x4_t a_imag_vec = a_vec.val[1]; + float32x4_t b_real_vec = b_vec.val[0]; + float32x4_t b_imag_vec = b_vec.val[1]; + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br - ai * bi; + ab_imag += ar * bi + ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +SIMSIMD_PUBLIC void simsimd_vdot_f32c_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts: + float32x4x2_t a_vec = vld2q_f32(a + i); + float32x4x2_t b_vec = vld2q_f32(b + i); + float32x4_t a_real_vec = a_vec.val[0]; + float32x4_t a_imag_vec = a_vec.val[1]; + float32x4_t b_real_vec = b_vec.val[0]; + float32x4_t b_imag_vec = b_vec.val[1]; + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmaq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmsq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br + ai * bi; + ab_imag += ar * bi - ai 
* br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+dotprod") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + int32x4_t ab_vec = vdupq_n_s32(0); + simsimd_size_t i = 0; + + // If the 128-bit `vdot_s32` intrinsic is unavailable, we can use the 64-bit `vdot_s32`. + // for (simsimd_size_t i = 0; i != n; i += 8) { + // int16x8_t a_vec = vmovl_s8(vld1_s8(a + i)); + // int16x8_t b_vec = vmovl_s8(vld1_s8(b + i)); + // int16x8_t ab_part_vec = vmulq_s16(a_vec, b_vec); + // ab_vec = vaddq_s32(ab_vec, vaddq_s32(vmovl_s16(vget_high_s16(ab_part_vec)), // + // vmovl_s16(vget_low_s16(ab_part_vec)))); + // } + for (; i + 16 <= n; i += 16) { + int8x16_t a_vec = vld1q_s8(a + i); + int8x16_t b_vec = vld1q_s8(b + i); + ab_vec = vdotq_s32(ab_vec, a_vec, b_vec); + } + + // Take care of the tail: + int32_t ab = vaddvq_s32(ab_vec); + for (; i < n; ++i) { + int32_t ai = a[i], bi = b[i]; + ab += ai * bi; + } + + *result = ab; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("+simd+fp16") +#pragma clang attribute push(__attribute__((target("+simd+fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i)); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i)); + ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + float16x4_t f16_vec; + simsimd_f16_t f16[4]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 4; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + ab_vec = vfmaq_f32(ab_vec, vcvt_f32_f16(a_padded_tail.f16_vec), vcvt_f32_f16(b_padded_tail.f16_vec)); + } + *result = vaddvq_f32(ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + + // A nicer approach is to use `f16` arithmetic for the dot product, but that requires + // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries + // at once. That's how the original implementation worked, but compiling it was a nightmare :) + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts. + // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed + // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. 
+ int16x4x2_t a_vec = vld2_s16((short*)a + i); + int16x4x2_t b_vec = vld2_s16((short*)b + i); + float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); + float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); + float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); + float32x4_t b_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[1])); + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br - ai * bi; + ab_imag += ar * bi + ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +SIMSIMD_PUBLIC void simsimd_vdot_f16c_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + + // A nicer approach is to use `f16` arithmetic for the dot product, but that requires + // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries + // at once. That's how the original implementation worked, but compiling it was a nightmare :) + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts. + // MSVC sadly doesn't recognize the `vld2_f16`, so we load the data as signed + // integers of the same size and reinterpret with `vreinterpret_f16_s16` afterwards. 
+ int16x4x2_t a_vec = vld2_s16((short*)a + i); + int16x4x2_t b_vec = vld2_s16((short*)b + i); + float32x4_t a_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[0])); + float32x4_t a_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(a_vec.val[1])); + float32x4_t b_real_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[0])); + float32x4_t b_imag_vec = vcvt_f32_f16(vreinterpret_f16_s16(b_vec.val[1])); + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmaq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmsq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br + ai * bi; + ab_imag += ar * bi - ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#if SIMSIMD_TARGET_NEON_BF16_IMPLEMENTED +#pragma GCC push_options +#pragma GCC target("+simd+bf16") +#pragma clang attribute push(__attribute__((target("+simd+bf16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_high_vec = vdupq_n_f32(0), ab_low_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i); + bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i); + ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_vec, b_vec); + ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_vec, b_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + bfloat16x8_t bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + } + *result = vaddvq_f32(ab_high_vec) + vaddvq_f32(ab_low_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + + // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires + // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries + // at once. That's how the original implementation worked, but compiling it was a nightmare :) + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts. + // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed + // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. 
+ int16x4x2_t a_vec = vld2_s16((short*)a + i); + int16x4x2_t b_vec = vld2_s16((short*)b + i); + float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0])); + float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1])); + float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0])); + float32x4_t b_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[1])); + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmsq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br - ai * bi; + ab_imag += ar * bi + ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, // + simsimd_distance_t* results) { + + // A nicer approach is to use `bf16` arithmetic for the dot product, but that requires + // FMLA extensions available on Arm v8.3 and later. That we can also process 16 entries + // at once. That's how the original implementation worked, but compiling it was a nightmare :) + float32x4_t ab_real_vec = vdupq_n_f32(0); + float32x4_t ab_imag_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Unpack the input arrays into real and imaginary parts. + // MSVC sadly doesn't recognize the `vld2_bf16`, so we load the data as signed + // integers of the same size and reinterpret with `vreinterpret_bf16_s16` afterwards. 
+ int16x4x2_t a_vec = vld2_s16((short*)a + i); + int16x4x2_t b_vec = vld2_s16((short*)b + i); + float32x4_t a_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[0])); + float32x4_t a_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(a_vec.val[1])); + float32x4_t b_real_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[0])); + float32x4_t b_imag_vec = vcvt_f32_bf16(vreinterpret_bf16_s16(b_vec.val[1])); + + // Compute the dot product: + ab_real_vec = vfmaq_f32(ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = vfmaq_f32(ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = vfmaq_f32(ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = vfmsq_f32(ab_imag_vec, a_imag_vec, b_real_vec); + } + + // Reduce horizontal sums: + simsimd_f32_t ab_real = vaddvq_f32(ab_real_vec); + simsimd_f32_t ab_imag = vaddvq_f32(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br + ai * bi; + ab_imag += ar * bi - ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif +#endif // SIMSIMD_TARGET_NEON + +#if SIMSIMD_TARGET_SVE + +#pragma GCC push_options +#pragma GCC target("+sve") +#pragma clang attribute push(__attribute__((target("+sve"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat32_t ab_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + do { + svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); + svfloat32_t a_vec = svld1_f32(pg_vec, a + i); + svfloat32_t b_vec = svld1_f32(pg_vec, b + i); + ab_vec = svmla_f32_x(pg_vec, ab_vec, a_vec, b_vec); + i += svcntw(); + } while (i < n); + *result = svaddv_f32(svptrue_b32(), ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat32_t ab_real_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + svfloat32_t ab_imag_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + do { + svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); + svfloat32x2_t a_vec = svld2_f32(pg_vec, a + i); + svfloat32x2_t b_vec = svld2_f32(pg_vec, b + i); + svfloat32_t a_real_vec = svget2_f32(a_vec, 0); + svfloat32_t a_imag_vec = svget2_f32(a_vec, 1); + svfloat32_t b_real_vec = svget2_f32(b_vec, 0); + svfloat32_t b_imag_vec = svget2_f32(b_vec, 1); + ab_real_vec = svmla_f32_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmls_f32_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcntw() * 2; + } while (i < n); + results[0] = svaddv_f32(svptrue_b32(), ab_real_vec); + results[1] = svaddv_f32(svptrue_b32(), ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f32c_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat32_t ab_real_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + svfloat32_t ab_imag_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + do { + svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); + svfloat32x2_t a_vec = svld2_f32(pg_vec, a + i); + svfloat32x2_t b_vec = svld2_f32(pg_vec, b + i); + svfloat32_t a_real_vec = svget2_f32(a_vec, 0); + svfloat32_t a_imag_vec = svget2_f32(a_vec, 1); + 
svfloat32_t b_real_vec = svget2_f32(b_vec, 0); + svfloat32_t b_imag_vec = svget2_f32(b_vec, 1); + ab_real_vec = svmla_f32_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmla_f32_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f32_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmls_f32_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcntw() * 2; + } while (i < n); + results[0] = svaddv_f32(svptrue_b32(), ab_real_vec); + results[1] = svaddv_f32(svptrue_b32(), ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat64_t ab_vec = svdupq_n_f64(0.0, 0.0); + do { + svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); + svfloat64_t a_vec = svld1_f64(pg_vec, a + i); + svfloat64_t b_vec = svld1_f64(pg_vec, b + i); + ab_vec = svmla_f64_x(pg_vec, ab_vec, a_vec, b_vec); + i += svcntd(); + } while (i < n); + *result = svaddv_f64(svptrue_b32(), ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat64_t ab_real_vec = svdupq_n_f64(0., 0.); + svfloat64_t ab_imag_vec = svdupq_n_f64(0., 0.); + do { + svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); + svfloat64x2_t a_vec = svld2_f64(pg_vec, a + i); + svfloat64x2_t b_vec = svld2_f64(pg_vec, b + i); + svfloat64_t a_real_vec = svget2_f64(a_vec, 0); + svfloat64_t a_imag_vec = svget2_f64(a_vec, 1); + svfloat64_t b_real_vec = svget2_f64(b_vec, 0); + svfloat64_t b_imag_vec = svget2_f64(b_vec, 1); + ab_real_vec = svmla_f64_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmls_f64_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcntd() * 2; + } while (i < n); + results[0] = svaddv_f64(svptrue_b64(), ab_real_vec); + results[1] = svaddv_f64(svptrue_b64(), ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f64c_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat64_t ab_real_vec = svdupq_n_f64(0., 0.); + svfloat64_t ab_imag_vec = svdupq_n_f64(0., 0.); + do { + svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); + svfloat64x2_t a_vec = svld2_f64(pg_vec, a + i); + svfloat64x2_t b_vec = svld2_f64(pg_vec, b + i); + svfloat64_t a_real_vec = svget2_f64(a_vec, 0); + svfloat64_t a_imag_vec = svget2_f64(a_vec, 1); + svfloat64_t b_real_vec = svget2_f64(b_vec, 0); + svfloat64_t b_imag_vec = svget2_f64(b_vec, 1); + ab_real_vec = svmla_f64_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmla_f64_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f64_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmls_f64_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcntd() * 2; + } while (i < n); + results[0] = svaddv_f64(svptrue_b64(), ab_real_vec); + results[1] = svaddv_f64(svptrue_b64(), ab_imag_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("+sve+fp16") +#pragma clang attribute push(__attribute__((target("+sve+fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, 
simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat16_t ab_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + simsimd_f16_for_arm_simd_t const* a = (simsimd_f16_for_arm_simd_t const*)(a_enum); + simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + do { + svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); + svfloat16_t a_vec = svld1_f16(pg_vec, a + i); + svfloat16_t b_vec = svld1_f16(pg_vec, b + i); + ab_vec = svmla_f16_x(pg_vec, ab_vec, a_vec, b_vec); + i += svcnth(); + } while (i < n); + simsimd_f16_for_arm_simd_t ab = svaddv_f16(svptrue_b16(), ab_vec); + *result = ab; +} + +SIMSIMD_PUBLIC void simsimd_dot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat16_t ab_real_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + svfloat16_t ab_imag_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + do { + svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)a + i); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)b + i); + svfloat16_t a_real_vec = svget2_f16(a_vec, 0); + svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); + svfloat16_t b_real_vec = svget2_f16(b_vec, 0); + svfloat16_t b_imag_vec = svget2_f16(b_vec, 1); + ab_real_vec = svmla_f16_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmls_f16_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcnth() * 2; + } while (i < n); + results[0] = svaddv_f16(svptrue_b16(), ab_real_vec); + results[1] = svaddv_f16(svptrue_b16(), ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + simsimd_size_t i = 0; + svfloat16_t ab_real_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + svfloat16_t ab_imag_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + do { + svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); + svfloat16x2_t a_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)a + i); + svfloat16x2_t b_vec = svld2_f16(pg_vec, (simsimd_f16_for_arm_simd_t const*)b + i); + svfloat16_t a_real_vec = svget2_f16(a_vec, 0); + svfloat16_t a_imag_vec = svget2_f16(a_vec, 1); + svfloat16_t b_real_vec = svget2_f16(b_vec, 0); + svfloat16_t b_imag_vec = svget2_f16(b_vec, 1); + ab_real_vec = svmla_f16_x(pg_vec, ab_real_vec, a_real_vec, b_real_vec); + ab_real_vec = svmla_f16_x(pg_vec, ab_real_vec, a_imag_vec, b_imag_vec); + ab_imag_vec = svmla_f16_x(pg_vec, ab_imag_vec, a_real_vec, b_imag_vec); + ab_imag_vec = svmls_f16_x(pg_vec, ab_imag_vec, a_imag_vec, b_real_vec); + i += svcnth() * 2; + } while (i < n); + results[0] = svaddv_f16(svptrue_b16(), ab_real_vec); + results[1] = svaddv_f16(svptrue_b16(), ab_imag_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SVE +#endif // SIMSIMD_TARGET_ARM + +#if SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("avx2", "f16c", "fma") +#pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 ab_vec = _mm256_setzero_ps(); + 
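+    // Each iteration converts 8 `f16` values per operand to `f32` with F16C and accumulates their products with FMA.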
simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + __m128i f16_vec; + simsimd_f16_t f16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + __m256 a_vec = _mm256_cvtph_ps(a_padded_tail.f16_vec); + __m256 b_vec = _mm256_cvtph_ps(b_padded_tail.f16_vec); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + } + + ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + + simsimd_f32_t f32_result; + _mm_store_ss(&f32_result, _mm256_castps256_ps128(ab_vec)); + *result = f32_result; +} + +inline simsimd_f64_t _mm256_reduce_add_ps_dbl(__m256 vec) { + // Convert the lower and higher 128-bit lanes of the input vector to double precision + __m128 low_f32 = _mm256_castps256_ps128(vec); + __m128 high_f32 = _mm256_extractf128_ps(vec, 1); + + // Convert single-precision (float) vectors to double-precision (double) vectors + __m256d low_f64 = _mm256_cvtps_pd(low_f32); + __m256d high_f64 = _mm256_cvtps_pd(high_f32); + + // Perform the addition in double-precision + __m256d sum = _mm256_add_pd(low_f64, high_f64); + + // Reduce the double-precision vector to a scalar + // Horizontal add the first and second double-precision values, and third and fourth + __m128d sum_low = _mm256_castpd256_pd128(sum); + __m128d sum_high = _mm256_extractf128_pd(sum, 1); + __m128d sum128 = _mm_add_pd(sum_low, sum_high); + + // Horizontal add again to accumulate all four values into one + sum128 = _mm_hadd_pd(sum128, sum128); + + // Convert the final sum to a scalar double-precision value and return + return _mm_cvtsd_f64(sum128); +} + +SIMSIMD_PUBLIC void simsimd_dot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + // The naive approach would be to use FMA and FMS instructions on different parts of the vectors. + // Prior to that we would need to shuffle the input vectors to separate real and imaginary parts. + // Both operations are quite expensive, and the resulting kernel would run at 2.5 GB/s. 
+ // __m128 ab_real_vec = _mm_setzero_ps(); + // __m128 ab_imag_vec = _mm_setzero_ps(); + // __m256i permute_vec = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); + // simsimd_size_t i = 0; + // for (; i + 8 <= n; i += 8) { + // __m256 a_vec = _mm256_loadu_ps(a + i); + // __m256 b_vec = _mm256_loadu_ps(b + i); + // __m256 a_shuffled = _mm256_permutevar8x32_ps(a_vec, permute_vec); + // __m256 b_shuffled = _mm256_permutevar8x32_ps(b_vec, permute_vec); + // __m128 a_real_vec = _mm256_extractf128_ps(a_shuffled, 0); + // __m128 a_imag_vec = _mm256_extractf128_ps(a_shuffled, 1); + // __m128 b_real_vec = _mm256_extractf128_ps(b_shuffled, 0); + // __m128 b_imag_vec = _mm256_extractf128_ps(b_shuffled, 1); + // ab_real_vec = _mm_fmadd_ps(a_real_vec, b_real_vec, ab_real_vec); + // ab_real_vec = _mm_fnmadd_ps(a_imag_vec, b_imag_vec, ab_real_vec); + // ab_imag_vec = _mm_fmadd_ps(a_real_vec, b_imag_vec, ab_imag_vec); + // ab_imag_vec = _mm_fmadd_ps(a_imag_vec, b_real_vec, ab_imag_vec); + // } + // + // Instead, we take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. + // Both operations are quite cheap, and the throughput doubles from 2.5 GB/s to 5 GB/s. + __m256 ab_real_vec = _mm256_setzero_ps(); + __m256 ab_imag_vec = _mm256_setzero_ps(); + __m256i sign_flip_vec = _mm256_set1_epi64x(0x8000000000000000); + __m256i swap_adjacent_vec = _mm256_set_epi8( // + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4, // Points to the second f32 in 128-bit lane + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4 // Points to the second f32 in 128-bit lane + ); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_loadu_ps(a + i); + __m256 b_vec = _mm256_loadu_ps(b + i); + __m256 b_flipped_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(b_vec), sign_flip_vec)); + __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_flipped_vec, ab_real_vec); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); + } + + // Reduce horizontal sums: + simsimd_distance_t ab_real = _mm256_reduce_add_ps_dbl(ab_real_vec); + simsimd_distance_t ab_imag = _mm256_reduce_add_ps_dbl(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br - ai * bi; + ab_imag += ar * bi + ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +SIMSIMD_PUBLIC void simsimd_vdot_f32c_haswell(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m256 ab_real_vec = _mm256_setzero_ps(); + __m256 ab_imag_vec = _mm256_setzero_ps(); + __m256i sign_flip_vec = _mm256_set1_epi64x(0x8000000000000000); + __m256i swap_adjacent_vec = _mm256_set_epi8( // + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit 
lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4, // Points to the second f32 in 128-bit lane + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4 // Points to the second f32 in 128-bit lane + ); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_loadu_ps(a + i); + __m256 b_vec = _mm256_loadu_ps(b + i); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); + a_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(a_vec), sign_flip_vec)); + b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); + } + + // Reduce horizontal sums: + simsimd_distance_t ab_real = _mm256_reduce_add_ps_dbl(ab_real_vec); + simsimd_distance_t ab_imag = _mm256_reduce_add_ps_dbl(ab_imag_vec); + + // Handle the tail: + for (; i + 2 <= n; i += 2) { + simsimd_f32_t ar = a[i], ai = a[i + 1], br = b[i], bi = b[i + 1]; + ab_real += ar * br + ai * bi; + ab_imag += ar * bi - ai * br; + } + results[0] = ab_real; + results[1] = ab_imag; +} + +SIMSIMD_PUBLIC void simsimd_dot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + // Ideally the implementation would load 256 bits worth of vector data at a time, + // shuffle those within a register, split in halfs, and only then upcast. + // That way, we are stepping through 32x 16-bit vector components at a time, or 16 dimensions. + // Sadly, shuffling 16-bit entries in a YMM register is hard to implement efficiently. + // + // Simpler approach is to load 128 bits at a time, upcast, and then shuffle. + // This mostly replicates the `simsimd_dot_f32c_haswell`. + __m256 ab_real_vec = _mm256_setzero_ps(); + __m256 ab_imag_vec = _mm256_setzero_ps(); + __m256i sign_flip_vec = _mm256_set1_epi64x(0x8000000000000000); + __m256i swap_adjacent_vec = _mm256_set_epi8( // + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4, // Points to the second f32 in 128-bit lane + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4 // Points to the second f32 in 128-bit lane + ); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + __m256 b_flipped_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(b_vec), sign_flip_vec)); + __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_flipped_vec, ab_real_vec); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
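+    // The remaining (< 8) entries are copied into a zero-padded buffer, so the same vectorized
+    // upcast-and-FMA path handles the tail; the zero products do not change the accumulated sums.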
+ if (i < n) { + union { + __m128i f16_vec; + simsimd_f16_t f16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + __m256 a_vec = _mm256_cvtph_ps(a_padded_tail.f16_vec); + __m256 b_vec = _mm256_cvtph_ps(b_padded_tail.f16_vec); + __m256 b_flipped_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(b_vec), sign_flip_vec)); + __m256 b_swapped_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_flipped_vec, ab_real_vec); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_swapped_vec, ab_imag_vec); + } + + // Reduce horizontal sums: + results[0] = _mm256_reduce_add_ps_dbl(ab_real_vec); + results[1] = _mm256_reduce_add_ps_dbl(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f16c_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m256 ab_real_vec = _mm256_setzero_ps(); + __m256 ab_imag_vec = _mm256_setzero_ps(); + __m256i sign_flip_vec = _mm256_set1_epi64x(0x8000000000000000); + __m256i swap_adjacent_vec = _mm256_set_epi8( // + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4, // Points to the second f32 in 128-bit lane + 11, 10, 9, 8, // Points to the third f32 in 128-bit lane + 15, 14, 13, 12, // Points to the fourth f32 in 128-bit lane + 3, 2, 1, 0, // Points to the first f32 in 128-bit lane + 7, 6, 5, 4 // Points to the second f32 in 128-bit lane + ); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); + a_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(a_vec), sign_flip_vec)); + b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
+ if (i < n) { + union { + __m128i f16_vec; + simsimd_f16_t f16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + __m256 a_vec = _mm256_cvtph_ps(a_padded_tail.f16_vec); + __m256 b_vec = _mm256_cvtph_ps(b_padded_tail.f16_vec); + ab_real_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_real_vec); + a_vec = _mm256_castsi256_ps(_mm256_xor_si256(_mm256_castps_si256(a_vec), sign_flip_vec)); + b_vec = _mm256_castsi256_ps(_mm256_shuffle_epi8(_mm256_castps_si256(b_vec), swap_adjacent_vec)); + ab_imag_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_imag_vec); + } + + // Reduce horizontal sums: + results[0] = _mm256_reduce_add_ps_dbl(ab_real_vec); + results[1] = _mm256_reduce_add_ps_dbl(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i ab_low_vec = _mm256_setzero_si256(); + __m256i ab_high_vec = _mm256_setzero_si256(); + + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i)); + __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i)); + + // Unpack int8 to int16 + __m256i a_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 0)); + __m256i a_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1)); + __m256i b_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 0)); + __m256i b_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1)); + + // Multiply and accumulate at int16 level, accumulate at int32 level + ab_low_vec = _mm256_add_epi32(ab_low_vec, _mm256_madd_epi16(a_low_16, b_low_16)); + ab_high_vec = _mm256_add_epi32(ab_high_vec, _mm256_madd_epi16(a_high_16, b_high_16)); + } + + // Horizontal sum across the 256-bit register + __m256i ab_vec = _mm256_add_epi32(ab_low_vec, ab_high_vec); + __m128i ab_sum = _mm_add_epi32(_mm256_extracti128_si256(ab_vec, 0), _mm256_extracti128_si256(ab_vec, 1)); + ab_sum = _mm_hadd_epi32(ab_sum, ab_sum); + ab_sum = _mm_hadd_epi32(ab_sum, ab_sum); + + // Take care of the tail: + int ab = _mm_extract_epi32(ab_sum, 0); + for (; i < n; ++i) + ab += a[i] * b[i]; + *result = ab; +} + +SIMSIMD_PUBLIC void simsimd_dot_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 ab_vec = _mm256_setzero_ps(); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like: + // x = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(x), 16)) + __m256 a_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(a + i))), 16)); + __m256 b_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(b + i))), 16)); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
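+    // A `bf16` value keeps the sign, exponent, and top 7 mantissa bits of an `f32`, so the left
+    // shift by 16 above is an exact upcast; the zero-padded tail below reuses the same path
+    // without affecting the sum.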
+ if (i < n) { + union { + __m128i bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + __m256 a_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a_padded_tail.bf16_vec), 16)); + __m256 b_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(b_padded_tail.bf16_vec), 16)); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + } + + ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + + simsimd_f32_t f32_result; + _mm_store_ss(&f32_result, _mm256_castps256_ps128(ab_vec)); + *result = f32_result; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "avx512bw", "bmi2") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,avx512bw,bmi2"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 ab_vec = _mm512_setzero(); + __m512 a_vec, b_vec; + +simsimd_dot_f32_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); + if (n) + goto simsimd_dot_f32_skylake_cycle; + + *result = _mm512_reduce_add_ps(ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512d ab_vec = _mm512_setzero_pd(); + __m512d a_vec, b_vec; + +simsimd_dot_f64_skylake_cycle: + if (n < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + a += 8, b += 8, n -= 8; + } + ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec); + if (n) + goto simsimd_dot_f64_skylake_cycle; + + *result = _mm512_reduce_add_pd(ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512 ab_real_vec = _mm512_setzero(); + __m512 ab_imag_vec = _mm512_setzero(); + __m512 a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
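+    // Per pair: (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br).
+    // XOR-ing the sign bit of b's imaginary lanes turns the real-part FMS into a plain FMA, and the
+    // byte shuffle that swaps b's (real, imag) pairs yields the imaginary part as a second FMA.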
+ __m512i sign_flip_vec = _mm512_set1_epi64(0x8000000000000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 59, 58, 57, 56, 63, 62, 61, 60, 51, 50, 49, 48, 55, 54, 53, 52, // 4th 128-bit lane + 43, 42, 41, 40, 47, 46, 45, 44, 35, 34, 33, 32, 39, 38, 37, 36, // 3rd 128-bit lane + 27, 26, 25, 24, 31, 30, 29, 28, 19, 18, 17, 16, 23, 22, 21, 20, // 2nd 128-bit lane + 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 // 1st 128-bit lane + ); +simsimd_dot_f32c_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_castps_si512(_mm512_maskz_loadu_ps(mask, b)); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_castps_si512(_mm512_loadu_ps(b)); + a += 16, b += 16, n -= 16; + } + ab_real_vec = _mm512_fmadd_ps(_mm512_castsi512_ps(_mm512_xor_si512(b_vec, sign_flip_vec)), a_vec, ab_real_vec); + ab_imag_vec = + _mm512_fmadd_ps(_mm512_castsi512_ps(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), a_vec, ab_imag_vec); + if (n) + goto simsimd_dot_f32c_skylake_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ps(ab_real_vec); + results[1] = _mm512_reduce_add_ps(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f32c_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512 ab_real_vec = _mm512_setzero(); + __m512 ab_imag_vec = _mm512_setzero(); + __m512 a_vec; + __m512 b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. + __m512i sign_flip_vec = _mm512_set1_epi64(0x8000000000000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 59, 58, 57, 56, 63, 62, 61, 60, 51, 50, 49, 48, 55, 54, 53, 52, // 4th 128-bit lane + 43, 42, 41, 40, 47, 46, 45, 44, 35, 34, 33, 32, 39, 38, 37, 36, // 3rd 128-bit lane + 27, 26, 25, 24, 31, 30, 29, 28, 19, 18, 17, 16, 23, 22, 21, 20, // 2nd 128-bit lane + 11, 10, 9, 8, 15, 14, 13, 12, 3, 2, 1, 0, 7, 6, 5, 4 // 1st 128-bit lane + ); +simsimd_vdot_f32c_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + ab_real_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_real_vec); + a_vec = _mm512_castsi512_ps(_mm512_xor_si512(_mm512_castps_si512(a_vec), sign_flip_vec)); + b_vec = _mm512_castsi512_ps(_mm512_shuffle_epi8(_mm512_castps_si512(b_vec), swap_adjacent_vec)); + ab_imag_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_imag_vec); + if (n) + goto simsimd_vdot_f32c_skylake_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ps(ab_real_vec); + results[1] = _mm512_reduce_add_ps(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512d ab_real_vec = _mm512_setzero_pd(); + __m512d ab_imag_vec = _mm512_setzero_pd(); + __m512d a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. 
+ // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. + __m512i sign_flip_vec = _mm512_set_epi64( // + 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, // + 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 // + ); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 55, 54, 53, 52, 51, 50, 49, 48, 63, 62, 61, 60, 59, 58, 57, 56, // 4th 128-bit lane + 39, 38, 37, 36, 35, 34, 33, 32, 47, 46, 45, 44, 43, 42, 41, 40, // 3rd 128-bit lane + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24, // 2nd 128-bit lane + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 // 1st 128-bit lane + ); +simsimd_dot_f64c_skylake_cycle: + if (n < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_castpd_si512(_mm512_maskz_loadu_pd(mask, b)); + n = 0; + } else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_castpd_si512(_mm512_loadu_pd(b)); + a += 8, b += 8, n -= 8; + } + ab_real_vec = _mm512_fmadd_pd(_mm512_castsi512_pd(_mm512_xor_si512(b_vec, sign_flip_vec)), a_vec, ab_real_vec); + ab_imag_vec = + _mm512_fmadd_pd(_mm512_castsi512_pd(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), a_vec, ab_imag_vec); + if (n) + goto simsimd_dot_f64c_skylake_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_pd(ab_real_vec); + results[1] = _mm512_reduce_add_pd(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f64c_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512d ab_real_vec = _mm512_setzero_pd(); + __m512d ab_imag_vec = _mm512_setzero_pd(); + __m512d a_vec; + __m512d b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
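+    // Unlike the `f32c` kernels, each `f64` complex number spans two 64-bit lanes, so the sign
+    // mask must alternate per lane instead of using `_mm512_set1_epi64`.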
+ __m512i sign_flip_vec = _mm512_set_epi64( // + 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000, // + 0x8000000000000000, 0x0000000000000000, 0x8000000000000000, 0x0000000000000000 // + ); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 55, 54, 53, 52, 51, 50, 49, 48, 63, 62, 61, 60, 59, 58, 57, 56, // 4th 128-bit lane + 39, 38, 37, 36, 35, 34, 33, 32, 47, 46, 45, 44, 43, 42, 41, 40, // 3rd 128-bit lane + 23, 22, 21, 20, 19, 18, 17, 16, 31, 30, 29, 28, 27, 26, 25, 24, // 2nd 128-bit lane + 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8 // 1st 128-bit lane + ); +simsimd_vdot_f64c_skylake_cycle: + if (n < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + a += 8, b += 8, n -= 8; + } + ab_real_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_real_vec); + a_vec = _mm512_castsi512_pd(_mm512_xor_si512(_mm512_castpd_si512(a_vec), sign_flip_vec)); + b_vec = _mm512_castsi512_pd(_mm512_shuffle_epi8(_mm512_castpd_si512(b_vec), swap_adjacent_vec)); + ab_imag_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_imag_vec); + if (n) + goto simsimd_vdot_f64c_skylake_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_pd(ab_real_vec); + results[1] = _mm512_reduce_add_pd(ab_imag_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SKYLAKE + +#if SIMSIMD_TARGET_GENOA +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 ab_vec = _mm512_setzero_ps(); + __m512i a_i16_vec, b_i16_vec; + +simsimd_dot_bf16_genoa_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_vec = _mm512_dpbf16_ps(ab_vec, (__m512bh)(a_i16_vec), (__m512bh)(b_i16_vec)); + if (n) + goto simsimd_dot_bf16_genoa_cycle; + + *result = _mm512_reduce_add_ps(ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512 ab_real_vec = _mm512_setzero_ps(); + __m512 ab_imag_vec = _mm512_setzero_ps(); + __m512i a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
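+    // `_mm512_dpbf16_ps` multiplies adjacent `bf16` pairs and accumulates each pair's sum of
+    // products into one `f32` lane, which maps directly onto the interleaved (real, imag) layout.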
+ __m512i sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + +simsimd_dot_bf16c_genoa_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_epi16(mask, a); + b_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_epi16(a); + b_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_real_vec = _mm512_dpbf16_ps(ab_real_vec, (__m512bh)(_mm512_xor_si512(b_vec, sign_flip_vec)), (__m512bh)(a_vec)); + ab_imag_vec = + _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), (__m512bh)(a_vec)); + if (n) + goto simsimd_dot_bf16c_genoa_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ps(ab_real_vec); + results[1] = _mm512_reduce_add_ps(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_bf16c_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512 ab_real_vec = _mm512_setzero_ps(); + __m512 ab_imag_vec = _mm512_setzero_ps(); + __m512i a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
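+    // For the conjugate product the real part is the plain pairwise dot `ar*br + ai*bi`, so the
+    // first `dpbf16` needs no sign flip; the imaginary part `ar*bi - ai*br` comes from flipping
+    // the sign of a's imaginary lanes and swapping b's pairs.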
+ __m512i sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + +simsimd_dot_bf16c_genoa_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_epi16(mask, a); + b_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_epi16(a); + b_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_real_vec = _mm512_dpbf16_ps(ab_real_vec, (__m512bh)(a_vec), (__m512bh)(b_vec)); + a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); + b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); + ab_imag_vec = _mm512_dpbf16_ps(ab_imag_vec, (__m512bh)(a_vec), (__m512bh)(b_vec)); + if (n) + goto simsimd_dot_bf16c_genoa_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ps(ab_real_vec); + results[1] = _mm512_reduce_add_ps(ab_imag_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_GENOA + +#if SIMSIMD_TARGET_SAPPHIRE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512fp16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512h ab_vec = _mm512_setzero_ph(); + __m512i a_i16_vec, b_i16_vec; + +simsimd_dot_f16_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec), ab_vec); + if (n) + goto simsimd_dot_f16_sapphire_cycle; + + *result = _mm512_reduce_add_ph(ab_vec); +} + +SIMSIMD_PUBLIC void simsimd_dot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512h ab_real_vec = _mm512_setzero_ph(); + __m512h ab_imag_vec = _mm512_setzero_ph(); + __m512i a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
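+    // With AVX-512 FP16 the FMA runs natively on 32 half-precision lanes, so no upcast to `f32`
+    // is needed; accumulation happens in `f16` and is widened only by the final reduction.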
+ __m512i sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + +simsimd_dot_f16c_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_epi16(mask, a); + b_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_epi16(a); + b_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_xor_si512(b_vec, sign_flip_vec)), + _mm512_castsi512_ph(a_vec), ab_real_vec); + ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(_mm512_shuffle_epi8(b_vec, swap_adjacent_vec)), + _mm512_castsi512_ph(a_vec), ab_imag_vec); + if (n) + goto simsimd_dot_f16c_sapphire_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ph(ab_real_vec); + results[1] = _mm512_reduce_add_ph(ab_imag_vec); +} + +SIMSIMD_PUBLIC void simsimd_vdot_f16c_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* results) { + + __m512h ab_real_vec = _mm512_setzero_ph(); + __m512h ab_imag_vec = _mm512_setzero_ph(); + __m512i a_vec; + __m512i b_vec; + + // We take into account, that FMS is the same as FMA with a negative multiplier. + // To multiply a floating-point value by -1, we can use the `XOR` instruction to flip the sign bit. + // This way we can avoid the shuffling and the need for separate real and imaginary parts. + // For the imaginary part of the product, we would need to swap the real and imaginary parts of + // one of the vectors. 
+ __m512i sign_flip_vec = _mm512_set1_epi32(0x80000000); + __m512i swap_adjacent_vec = _mm512_set_epi8( // + 61, 60, 63, 62, 57, 56, 59, 58, 53, 52, 55, 54, 49, 48, 51, 50, // 4th 128-bit lane + 45, 44, 47, 46, 41, 40, 43, 42, 37, 36, 39, 38, 33, 32, 35, 34, // 3rd 128-bit lane + 29, 28, 31, 30, 25, 24, 27, 26, 21, 20, 23, 22, 17, 16, 19, 18, // 2nd 128-bit lane + 13, 12, 15, 14, 9, 8, 11, 10, 5, 4, 7, 6, 1, 0, 3, 2 // 1st 128-bit lane + ); + +simsimd_dot_f16c_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_epi16(mask, a); + b_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_epi16(a); + b_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_real_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_real_vec); + a_vec = _mm512_xor_si512(a_vec, sign_flip_vec); + b_vec = _mm512_shuffle_epi8(b_vec, swap_adjacent_vec); + ab_imag_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_vec), _mm512_castsi512_ph(b_vec), ab_imag_vec); + if (n) + goto simsimd_dot_f16c_sapphire_cycle; + + // Reduce horizontal sums: + results[0] = _mm512_reduce_add_ph(ab_real_vec); + results[1] = _mm512_reduce_add_ph(ab_imag_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SAPPHIRE + +#if SIMSIMD_TARGET_ICE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vnni") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_dot_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512i ab_i32s_vec = _mm512_setzero_si512(); + __m512i a_vec, b_vec; + +simsimd_dot_i8_ice_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); + b_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); + n = 0; + } else { + a_vec = _mm512_cvtepi8_epi16(_mm256_loadu_epi8(a)); + b_vec = _mm512_cvtepi8_epi16(_mm256_loadu_epi8(b)); + a += 32, b += 32, n -= 32; + } + // Unfortunately we can't use the `_mm512_dpbusd_epi32` intrinsics here either, + // as it's asymmetric with respect to the sign of the input arguments: + // Signed(ZeroExtend16(a.byte[4*j]) * SignExtend16(b.byte[4*j])) + // So we have to use the `_mm512_dpwssd_epi32` intrinsics instead, upcasting + // to 16-bit beforehand. + ab_i32s_vec = _mm512_dpwssd_epi32(ab_i32s_vec, a_vec, b_vec); + if (n) + goto simsimd_dot_i8_ice_cycle; + + *result = _mm512_reduce_add_epi32(ab_i32s_vec); +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_ICE +#endif // SIMSIMD_TARGET_X86 + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/include/simsimd/geospatial.h b/src/include/simsimd/geospatial.h new file mode 100644 index 0000000..fe2c955 --- /dev/null +++ b/src/include/simsimd/geospatial.h @@ -0,0 +1,37 @@ +/** + * @file geospatial.h + * @brief SIMD-accelerated Geo-Spatial distance functions. 
+ * @author Ash Vardanian + * @date July 1, 2023 + * + * Contains: + * - Haversine (Great Circle) distance + * - TODO: Vincenty's distance function for Oblate Spheroid Geodesics + * + * For datatypes: + * - 32-bit IEEE-754 floating point + * - 64-bit IEEE-754 floating point + * + * For hardware architectures: + * - Arm (NEON, SVE) + * - x86 (AVX512) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + * Oblate Spheroid Geodesic: https://mathworld.wolfram.com/OblateSpheroidGeodesic.html + * Staging experiments: https://github.com/ashvardanian/HaversineSimSIMD + */ +#ifndef SIMSIMD_GEOSPATIAL_H +#define SIMSIMD_GEOSPATIAL_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/include/simsimd/probability.h b/src/include/simsimd/probability.h new file mode 100644 index 0000000..962a1d9 --- /dev/null +++ b/src/include/simsimd/probability.h @@ -0,0 +1,573 @@ +/** + * @file probability.h + * @brief SIMD-accelerated Similarity Measures for Probability Distributions. + * @author Ash Vardanian + * @date October 20, 2023 + * + * Contains: + * - Kullback-Leibler divergence + * - Jensen–Shannon divergence + * + * For datatypes: + * - 32-bit floating point numbers + * - 16-bit floating point numbers + * + * For hardware architectures: + * - Arm (NEON, SVE) + * - x86 (AVX2, AVX512) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + */ +#ifndef SIMSIMD_PROBABILITY_H +#define SIMSIMD_PROBABILITY_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// clang-format off + +/* Serial backends for all numeric types. + * By default they use 32-bit arithmetic, unless the arguments themselves contain 64-bit floats. + * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. + */ +SIMSIMD_PUBLIC void simsimd_kl_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); + +/* Double-precision serial backends for all numeric types. + * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. 
+ */ +SIMSIMD_PUBLIC void simsimd_kl_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); + +/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. + * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all + * server CPUs produced before 2023. + */ +SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); + +/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words. + * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420. + * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms. + * On other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are + * properly vectorized by recent compilers. + */ +SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); + +/* SIMD-powered backends for various generations of AVX512 CPUs. + * Skylake is handy, as it supports masked loads and other operations, avoiding the need for the tail loop. + * Ice Lake added VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, and other extensions for integral operations. + * Sapphire Rapids added tiled matrix operations, but we are most interested in the new mixed-precision FMA instructions. 
+ */ +SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* divergence); +// clang-format on + +#define SIMSIMD_MAKE_KL(name, input_type, accumulator_type, converter, epsilon) \ + SIMSIMD_PUBLIC void simsimd_kl_##input_type##_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* result) { \ + simsimd_##accumulator_type##_t d = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = converter(a[i]); \ + simsimd_##accumulator_type##_t bi = converter(b[i]); \ + d += ai * SIMSIMD_LOG((ai + epsilon) / (bi + epsilon)); \ + } \ + *result = (simsimd_distance_t)d; \ + } + +#define SIMSIMD_MAKE_JS(name, input_type, accumulator_type, converter, epsilon) \ + SIMSIMD_PUBLIC void simsimd_js_##input_type##_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* result) { \ + simsimd_##accumulator_type##_t d = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = converter(a[i]); \ + simsimd_##accumulator_type##_t bi = converter(b[i]); \ + simsimd_##accumulator_type##_t mi = (ai + bi) / 2; \ + d += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); \ + d += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); \ + } \ + *result = (simsimd_distance_t)d / 2; \ + } + +SIMSIMD_MAKE_KL(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f64_serial +SIMSIMD_MAKE_JS(serial, f64, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f64_serial + +SIMSIMD_MAKE_KL(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f32_serial +SIMSIMD_MAKE_JS(serial, f32, f32, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f32_serial + +SIMSIMD_MAKE_KL(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f16_serial +SIMSIMD_MAKE_JS(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f16_serial + +SIMSIMD_MAKE_KL(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_bf16_serial +SIMSIMD_MAKE_JS(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_bf16_serial + +SIMSIMD_MAKE_KL(accurate, f32, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f32_accurate +SIMSIMD_MAKE_JS(accurate, f32, f64, SIMSIMD_IDENTIFY, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f32_accurate + +SIMSIMD_MAKE_KL(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_f16_accurate +SIMSIMD_MAKE_JS(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_f16_accurate + +SIMSIMD_MAKE_KL(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_kl_bf16_accurate +SIMSIMD_MAKE_JS(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16, SIMSIMD_F32_DIVISION_EPSILON) // simsimd_js_bf16_accurate + +#if SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC 
push_options +#pragma GCC target("+simd") +#pragma clang attribute push(__attribute__((target("+simd"))), apply_to = function) + +SIMSIMD_PUBLIC float32x4_t simsimd_log2_f32_neon(float32x4_t x) { + // Extracting the exponent + int32x4_t i = vreinterpretq_s32_f32(x); + int32x4_t e = vsubq_s32(vshrq_n_s32(vandq_s32(i, vdupq_n_s32(0x7F800000)), 23), vdupq_n_s32(127)); + float32x4_t e_float = vcvtq_f32_s32(e); + + // Extracting the mantissa + float32x4_t m = vreinterpretq_f32_s32(vorrq_s32(vandq_s32(i, vdupq_n_s32(0x007FFFFF)), vdupq_n_s32(0x3F800000))); + + // Constants for polynomial + float32x4_t one = vdupq_n_f32(1.0f); + float32x4_t p = vdupq_n_f32(-3.4436006e-2f); + + // Compute polynomial using Horner's method + p = vmlaq_f32(vdupq_n_f32(3.1821337e-1f), m, p); + p = vmlaq_f32(vdupq_n_f32(-1.2315303f), m, p); + p = vmlaq_f32(vdupq_n_f32(2.5988452f), m, p); + p = vmlaq_f32(vdupq_n_f32(-3.3241990f), m, p); + p = vmlaq_f32(vdupq_n_f32(3.1157899f), m, p); + + // Final computation + float32x4_t result = vaddq_f32(vmulq_f32(p, vsubq_f32(m, one)), e_float); + return result; +} + +SIMSIMD_PUBLIC void simsimd_kl_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + float32x4_t epsilon_vec = vdupq_n_f32(epsilon); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + float32x4_t ratio_vec = vdivq_f32(vaddq_f32(a_vec, epsilon_vec), vaddq_f32(b_vec, epsilon_vec)); + float32x4_t log_ratio_vec = simsimd_log2_f32_neon(ratio_vec); + float32x4_t prod_vec = vmulq_f32(a_vec, log_ratio_vec); + sum_vec = vaddq_f32(sum_vec, prod_vec); + } + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; + for (; i < n; ++i) + sum += a[i] * SIMSIMD_LOG((a[i] + epsilon) / (b[i] + epsilon)); + *result = sum; +} + +SIMSIMD_PUBLIC void simsimd_js_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + float32x4_t epsilon_vec = vdupq_n_f32(epsilon); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + float32x4_t m_vec = vmulq_f32(vaddq_f32(a_vec, b_vec), vdupq_n_f32(0.5)); + float32x4_t ratio_a_vec = vdivq_f32(vaddq_f32(a_vec, epsilon_vec), vaddq_f32(m_vec, epsilon_vec)); + float32x4_t ratio_b_vec = vdivq_f32(vaddq_f32(b_vec, epsilon_vec), vaddq_f32(m_vec, epsilon_vec)); + float32x4_t log_ratio_a_vec = simsimd_log2_f32_neon(ratio_a_vec); + float32x4_t log_ratio_b_vec = simsimd_log2_f32_neon(ratio_b_vec); + float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec); + float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec); + sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec)); + } + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; + for (; i < n; ++i) { + simsimd_f32_t mi = 0.5f * (a[i] + b[i]); + sum += a[i] * SIMSIMD_LOG((a[i] + epsilon) / (mi + epsilon)); + sum += b[i] * SIMSIMD_LOG((b[i] + epsilon) / (mi + epsilon)); + } + *result = sum; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("+simd+fp16") +#pragma clang attribute push(__attribute__((target("+simd+fp16"))), apply_to = 
function) + +SIMSIMD_PUBLIC void simsimd_kl_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + float32x4_t epsilon_vec = vdupq_n_f32(epsilon); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i)); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i)); + float32x4_t ratio_vec = vdivq_f32(vaddq_f32(a_vec, epsilon_vec), vaddq_f32(b_vec, epsilon_vec)); + float32x4_t log_ratio_vec = simsimd_log2_f32_neon(ratio_vec); + float32x4_t prod_vec = vmulq_f32(a_vec, log_ratio_vec); + sum_vec = vaddq_f32(sum_vec, prod_vec); + } + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; + for (; i < n; ++i) + sum += SIMSIMD_UNCOMPRESS_F16(a[i]) * + SIMSIMD_LOG((SIMSIMD_UNCOMPRESS_F16(a[i]) + epsilon) / (SIMSIMD_UNCOMPRESS_F16(b[i]) + epsilon)); + *result = sum; +} + +SIMSIMD_PUBLIC void simsimd_js_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + float32x4_t epsilon_vec = vdupq_n_f32(epsilon); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i)); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i)); + float32x4_t m_vec = vmulq_f32(vaddq_f32(a_vec, b_vec), vdupq_n_f32(0.5)); + float32x4_t ratio_a_vec = vdivq_f32(vaddq_f32(a_vec, epsilon_vec), vaddq_f32(m_vec, epsilon_vec)); + float32x4_t ratio_b_vec = vdivq_f32(vaddq_f32(b_vec, epsilon_vec), vaddq_f32(m_vec, epsilon_vec)); + float32x4_t log_ratio_a_vec = simsimd_log2_f32_neon(ratio_a_vec); + float32x4_t log_ratio_b_vec = simsimd_log2_f32_neon(ratio_b_vec); + float32x4_t prod_a_vec = vmulq_f32(a_vec, log_ratio_a_vec); + float32x4_t prod_b_vec = vmulq_f32(b_vec, log_ratio_b_vec); + sum_vec = vaddq_f32(sum_vec, vaddq_f32(prod_a_vec, prod_b_vec)); + } + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum = vaddvq_f32(sum_vec) * log2_normalizer; + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_UNCOMPRESS_F16(a[i]); + simsimd_f32_t bi = SIMSIMD_UNCOMPRESS_F16(b[i]); + simsimd_f32_t mi = 0.5f * (ai + bi); + sum += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); + sum += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); + } + *result = sum; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON +#endif // SIMSIMD_TARGET_ARM + +#if SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("avx2", "f16c", "fma") +#pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) + +inline __m256 simsimd_log2_f32_haswell(__m256 x) { + // Extracting the exponent + __m256i i = _mm256_castps_si256(x); + __m256i e = _mm256_srli_epi32(_mm256_and_si256(i, _mm256_set1_epi32(0x7F800000)), 23); + e = _mm256_sub_epi32(e, _mm256_set1_epi32(127)); // removing the bias + __m256 e_float = _mm256_cvtepi32_ps(e); + + // Extracting the mantissa + __m256 m = _mm256_castsi256_ps( + _mm256_or_si256(_mm256_and_si256(i, _mm256_set1_epi32(0x007FFFFF)), _mm256_set1_epi32(0x3F800000))); + + // Constants for polynomial + __m256 one = _mm256_set1_ps(1.0f); + __m256 p = 
_mm256_set1_ps(-3.4436006e-2f); + + // Compute the polynomial using Horner's method + p = _mm256_fmadd_ps(m, p, _mm256_set1_ps(3.1821337e-1f)); + p = _mm256_fmadd_ps(m, p, _mm256_set1_ps(-1.2315303f)); + p = _mm256_fmadd_ps(m, p, _mm256_set1_ps(2.5988452f)); + p = _mm256_fmadd_ps(m, p, _mm256_set1_ps(-3.3241990f)); + p = _mm256_fmadd_ps(m, p, _mm256_set1_ps(3.1157899f)); + + // Final computation + __m256 result = _mm256_add_ps(_mm256_mul_ps(p, _mm256_sub_ps(m, one)), e_float); + return result; +} + +SIMSIMD_PUBLIC void simsimd_kl_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 sum_vec = _mm256_setzero_ps(); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + __m256 epsilon_vec = _mm256_set1_ps(epsilon); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + __m256 ratio_vec = _mm256_div_ps(_mm256_add_ps(a_vec, epsilon_vec), _mm256_add_ps(b_vec, epsilon_vec)); + __m256 log_ratio_vec = simsimd_log2_f32_haswell(ratio_vec); + __m256 prod_vec = _mm256_mul_ps(a_vec, log_ratio_vec); + sum_vec = _mm256_add_ps(sum_vec, prod_vec); + } + + sum_vec = _mm256_add_ps(_mm256_permute2f128_ps(sum_vec, sum_vec, 1), sum_vec); + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum; + _mm_store_ss(&sum, _mm256_castps256_ps128(sum_vec)); + sum *= log2_normalizer; + + // Accumulate the tail: + for (; i < n; ++i) + sum += SIMSIMD_UNCOMPRESS_F16(a[i]) * + SIMSIMD_LOG((SIMSIMD_UNCOMPRESS_F16(a[i]) + epsilon) / (SIMSIMD_UNCOMPRESS_F16(b[i]) + epsilon)); + *result = sum; +} + +SIMSIMD_PUBLIC void simsimd_js_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 sum_vec = _mm256_setzero_ps(); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + __m256 epsilon_vec = _mm256_set1_ps(epsilon); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))), epsilon_vec); + __m256 b_vec = _mm256_add_ps(_mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))), epsilon_vec); + __m256 m_vec = _mm256_mul_ps(_mm256_add_ps(a_vec, b_vec), _mm256_set1_ps(0.5f)); // M = (P + Q) / 2 + __m256 ratio_a_vec = _mm256_div_ps(a_vec, m_vec); + __m256 ratio_b_vec = _mm256_div_ps(b_vec, m_vec); + __m256 log_ratio_a_vec = simsimd_log2_f32_haswell(ratio_a_vec); + __m256 log_ratio_b_vec = simsimd_log2_f32_haswell(ratio_b_vec); + __m256 prod_a_vec = _mm256_mul_ps(a_vec, log_ratio_a_vec); + __m256 prod_b_vec = _mm256_mul_ps(b_vec, log_ratio_b_vec); + sum_vec = _mm256_add_ps(sum_vec, prod_a_vec); + sum_vec = _mm256_add_ps(sum_vec, prod_b_vec); + } + + sum_vec = _mm256_add_ps(_mm256_permute2f128_ps(sum_vec, sum_vec, 1), sum_vec); + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + sum_vec = _mm256_hadd_ps(sum_vec, sum_vec); + + simsimd_f32_t log2_normalizer = 0.693147181f; + simsimd_f32_t sum; + _mm_store_ss(&sum, _mm256_castps256_ps128(sum_vec)); + sum *= log2_normalizer; + + // Accumulate the tail: + for (; i < n; ++i) { + simsimd_f32_t ai = SIMSIMD_UNCOMPRESS_F16(a[i]); + simsimd_f32_t bi = SIMSIMD_UNCOMPRESS_F16(b[i]); + simsimd_f32_t mi = ai + bi; + sum += ai * SIMSIMD_LOG((ai + epsilon) / (mi + epsilon)); + sum += bi * SIMSIMD_LOG((bi + epsilon) / (mi + epsilon)); + } + 
*result = sum / 2; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2"))), apply_to = function) + +inline __m512 simsimd_log2_f32_skylake(__m512 x) { + // Extract the exponent and mantissa + __m512 one = _mm512_set1_ps(1.0f); + __m512 e = _mm512_getexp_ps(x); + __m512 m = _mm512_getmant_ps(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src); + + // Compute the polynomial using Horner's method + __m512 p = _mm512_set1_ps(-3.4436006e-2f); + p = _mm512_fmadd_ps(m, p, _mm512_set1_ps(3.1821337e-1f)); + p = _mm512_fmadd_ps(m, p, _mm512_set1_ps(-1.2315303f)); + p = _mm512_fmadd_ps(m, p, _mm512_set1_ps(2.5988452f)); + p = _mm512_fmadd_ps(m, p, _mm512_set1_ps(-3.3241990f)); + p = _mm512_fmadd_ps(m, p, _mm512_set1_ps(3.1157899f)); + + return _mm512_add_ps(_mm512_mul_ps(p, _mm512_sub_ps(m, one)), e); +} + +SIMSIMD_PUBLIC void simsimd_kl_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 sum_vec = _mm512_setzero(); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + __m512 epsilon_vec = _mm512_set1_ps(epsilon); + __m512 a_vec, b_vec; + +simsimd_kl_f32_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_add_ps(_mm512_maskz_loadu_ps(mask, a), epsilon_vec); + b_vec = _mm512_add_ps(_mm512_maskz_loadu_ps(mask, b), epsilon_vec); + n = 0; + } else { + a_vec = _mm512_add_ps(_mm512_loadu_ps(a), epsilon_vec); + b_vec = _mm512_add_ps(_mm512_loadu_ps(b), epsilon_vec); + a += 16, b += 16, n -= 16; + } + __m512 ratio_vec = _mm512_div_ps(a_vec, b_vec); + __m512 log_ratio_vec = simsimd_log2_f32_skylake(ratio_vec); + __m512 prod_vec = _mm512_mul_ps(a_vec, log_ratio_vec); + sum_vec = _mm512_add_ps(sum_vec, prod_vec); + if (n) + goto simsimd_kl_f32_skylake_cycle; + + simsimd_f32_t log2_normalizer = 0.693147181f; + *result = _mm512_reduce_add_ps(sum_vec) * log2_normalizer; +} + +SIMSIMD_PUBLIC void simsimd_js_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 sum_a_vec = _mm512_setzero(); + __m512 sum_b_vec = _mm512_setzero(); + simsimd_f32_t epsilon = SIMSIMD_F32_DIVISION_EPSILON; + __m512 epsilon_vec = _mm512_set1_ps(epsilon); + __m512 a_vec, b_vec; + +simsimd_js_f32_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + __m512 m_vec = _mm512_mul_ps(_mm512_add_ps(a_vec, b_vec), _mm512_set1_ps(0.5f)); + __mmask16 nonzero_mask_a = _mm512_cmp_ps_mask(a_vec, epsilon_vec, _CMP_GE_OQ); + __mmask16 nonzero_mask_b = _mm512_cmp_ps_mask(b_vec, epsilon_vec, _CMP_GE_OQ); + __mmask16 nonzero_mask = nonzero_mask_a & nonzero_mask_b; + __m512 m_recip_approx = _mm512_rcp14_ps(m_vec); + __m512 ratio_a_vec = _mm512_mul_ps(a_vec, m_recip_approx); + __m512 ratio_b_vec = _mm512_mul_ps(b_vec, m_recip_approx); + __m512 log_ratio_a_vec = simsimd_log2_f32_skylake(ratio_a_vec); + __m512 log_ratio_b_vec = simsimd_log2_f32_skylake(ratio_b_vec); + sum_a_vec = _mm512_maskz_fmadd_ps(nonzero_mask, a_vec, log_ratio_a_vec, sum_a_vec); + sum_b_vec = _mm512_maskz_fmadd_ps(nonzero_mask, b_vec, log_ratio_b_vec, sum_b_vec); + if (n) 
+ goto simsimd_js_f32_skylake_cycle; + + simsimd_f32_t log2_normalizer = 0.693147181f; + *result = _mm512_reduce_add_ps(_mm512_add_ps(sum_a_vec, sum_b_vec)) * 0.5f * log2_normalizer; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SAPPHIRE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512fp16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512fp16"))), apply_to = function) + +inline __m512h simsimd_log2_f16_sapphire(__m512h x) { + // Extract the exponent and mantissa + __m512h one = _mm512_set1_ph((simsimd_f16_t)1); + __m512h e = _mm512_getexp_ph(x); + __m512h m = _mm512_getmant_ph(x, _MM_MANT_NORM_1_2, _MM_MANT_SIGN_src); + + // Compute the polynomial using Horner's method + __m512h p = _mm512_set1_ph((simsimd_f16_t)-3.4436006e-2f); + p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((simsimd_f16_t)3.1821337e-1f)); + p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((simsimd_f16_t)-1.2315303f)); + p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((simsimd_f16_t)2.5988452f)); + p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((simsimd_f16_t)-3.3241990f)); + p = _mm512_fmadd_ph(m, p, _mm512_set1_ph((simsimd_f16_t)3.1157899f)); + + return _mm512_add_ph(_mm512_mul_ph(p, _mm512_sub_ph(m, one)), e); +} + +SIMSIMD_PUBLIC void simsimd_kl_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512h sum_vec = _mm512_setzero_ph(); + __m512h epsilon_vec = _mm512_set1_ph((simsimd_f16_t)SIMSIMD_F16_DIVISION_EPSILON); + __m512h a_vec, b_vec; + +simsimd_kl_f16_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_add_ph(mask, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)), epsilon_vec); + b_vec = _mm512_maskz_add_ph(mask, _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)), epsilon_vec); + n = 0; + } else { + a_vec = _mm512_add_ph(_mm512_castsi512_ph(_mm512_loadu_epi16(a)), epsilon_vec); + b_vec = _mm512_add_ph(_mm512_castsi512_ph(_mm512_loadu_epi16(b)), epsilon_vec); + a += 32, b += 32, n -= 32; + } + __m512h ratio_vec = _mm512_div_ph(a_vec, b_vec); + __m512h log_ratio_vec = simsimd_log2_f16_sapphire(ratio_vec); + __m512h prod_vec = _mm512_mul_ph(a_vec, log_ratio_vec); + sum_vec = _mm512_add_ph(sum_vec, prod_vec); + if (n) + goto simsimd_kl_f16_sapphire_cycle; + + simsimd_f32_t log2_normalizer = 0.693147181f; + *result = _mm512_reduce_add_ph(sum_vec) * log2_normalizer; +} + +SIMSIMD_PUBLIC void simsimd_js_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512h sum_a_vec = _mm512_setzero_ph(); + __m512h sum_b_vec = _mm512_setzero_ph(); + __m512h epsilon_vec = _mm512_set1_ph((simsimd_f16_t)SIMSIMD_F16_DIVISION_EPSILON); + __m512h a_vec, b_vec; + +simsimd_js_f16_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, a)); + b_vec = _mm512_castsi512_ph(_mm512_maskz_loadu_epi16(mask, b)); + n = 0; + } else { + a_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(a)); + b_vec = _mm512_castsi512_ph(_mm512_loadu_epi16(b)); + a += 32, b += 32, n -= 32; + } + __m512h m_vec = _mm512_mul_ph(_mm512_add_ph(a_vec, b_vec), _mm512_set1_ph((simsimd_f16_t)0.5f)); + __mmask32 nonzero_mask_a = _mm512_cmp_ph_mask(a_vec, epsilon_vec, _CMP_GE_OQ); + __mmask32 nonzero_mask_b = _mm512_cmp_ph_mask(b_vec, epsilon_vec, _CMP_GE_OQ); + __mmask32 nonzero_mask = 
nonzero_mask_a & nonzero_mask_b; + __m512h m_recip_approx = _mm512_rcp_ph(m_vec); + __m512h ratio_a_vec = _mm512_mul_ph(a_vec, m_recip_approx); + __m512h ratio_b_vec = _mm512_mul_ph(b_vec, m_recip_approx); + __m512h log_ratio_a_vec = simsimd_log2_f16_sapphire(ratio_a_vec); + __m512h log_ratio_b_vec = simsimd_log2_f16_sapphire(ratio_b_vec); + sum_a_vec = _mm512_maskz_fmadd_ph(nonzero_mask, a_vec, log_ratio_a_vec, sum_a_vec); + sum_b_vec = _mm512_maskz_fmadd_ph(nonzero_mask, b_vec, log_ratio_b_vec, sum_b_vec); + if (n) + goto simsimd_js_f16_sapphire_cycle; + + simsimd_f32_t log2_normalizer = 0.693147181f; + *result = _mm512_reduce_add_ph(_mm512_add_ph(sum_a_vec, sum_b_vec)) * 0.5f * log2_normalizer; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SAPPHIRE +#endif // SIMSIMD_TARGET_X86 + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/include/simsimd/simsimd.h b/src/include/simsimd/simsimd.h new file mode 100644 index 0000000..5f8f98c --- /dev/null +++ b/src/include/simsimd/simsimd.h @@ -0,0 +1,1261 @@ +/** + * @file simsimd.h + * @brief SIMD-accelerated Similarity Measures and Distance Functions. + * @author Ash Vardanian + * @date March 14, 2023 + * @copyright Copyright (c) 2023 + * + * References: + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + * Detecting target CPU features at compile time: https://stackoverflow.com/a/28939692/2766161 + */ + +#ifndef SIMSIMD_H +#define SIMSIMD_H + +#define SIMSIMD_VERSION_MAJOR 4 +#define SIMSIMD_VERSION_MINOR 3 +#define SIMSIMD_VERSION_PATCH 1 + +/** + * @brief Removes compile-time dispatching, and replaces it with runtime dispatching. + * So the `simsimd_dot_f32` function will invoke the most advanced backend supported by the CPU, + * that runs the program, rather than the most advanced backend supported by the CPU + * used to compile the library or the downstream application. + */ +#ifndef SIMSIMD_DYNAMIC_DISPATCH +#define SIMSIMD_DYNAMIC_DISPATCH (0) // true or false +#endif + +#include "binary.h" // Hamming, Jaccard +#include "dot.h" // Inner (dot) product, and its conjugate +#include "geospatial.h" // Haversine and Vincenty +#include "probability.h" // Kullback-Leibler, Jensen–Shannon +#include "spatial.h" // L2, Cosine + +#if SIMSIMD_TARGET_ARM +#ifdef __linux__ +#include +#include +#endif +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * @brief Enumeration of supported metric kinds. + * Some have aliases for convenience. 
+ */ +typedef enum { + simsimd_metric_unknown_k = 0, ///< Unknown metric kind + + // Classics: + simsimd_metric_dot_k = 'i', ///< Inner product + simsimd_metric_inner_k = 'i', ///< Inner product alias + + simsimd_metric_vdot_k = 'v', ///< Complex inner product + + simsimd_metric_cos_k = 'c', ///< Cosine similarity + simsimd_metric_cosine_k = 'c', ///< Cosine similarity alias + simsimd_metric_angular_k = 'c', ///< Cosine similarity alias + + simsimd_metric_l2sq_k = 'e', ///< Squared Euclidean distance + simsimd_metric_sqeuclidean_k = 'e', ///< Squared Euclidean distance alias + + // Binary: + simsimd_metric_hamming_k = 'h', ///< Hamming distance + simsimd_metric_manhattan_k = 'h', ///< Manhattan distance is same as Hamming + + simsimd_metric_jaccard_k = 'j', ///< Jaccard coefficient + simsimd_metric_tanimoto_k = 'j', ///< Tanimoto coefficient is same as Jaccard + + // Probability: + simsimd_metric_kl_k = 'k', ///< Kullback-Leibler divergence + simsimd_metric_kullback_leibler_k = 'k', ///< Kullback-Leibler divergence alias + + simsimd_metric_js_k = 's', ///< Jensen-Shannon divergence + simsimd_metric_jensen_shannon_k = 's', ///< Jensen-Shannon divergence alias + +} simsimd_metric_kind_t; + +/** + * @brief Enumeration of SIMD capabilities of the target architecture. + */ +typedef enum { + simsimd_cap_serial_k = 1, ///< Serial (non-SIMD) capability + simsimd_cap_any_k = 0x7FFFFFFF, ///< Mask representing any capability with `INT_MAX` + + simsimd_cap_neon_k = 1 << 10, ///< ARM NEON capability + simsimd_cap_sve_k = 1 << 11, ///< ARM SVE capability + simsimd_cap_sve2_k = 1 << 12, ///< ARM SVE2 capability + + simsimd_cap_haswell_k = 1 << 20, ///< x86 AVX2 capability with FMA and F16C extensions + simsimd_cap_skylake_k = 1 << 21, ///< x86 AVX512 baseline capability + simsimd_cap_ice_k = 1 << 22, ///< x86 AVX512 capability with advanced integer algos + simsimd_cap_sapphire_k = 1 << 23, ///< x86 AVX512 capability with `f16` support + simsimd_cap_genoa_k = 1 << 24, ///< x86 AVX512 capability with `bf16` support + +} simsimd_capability_t; + +/** + * @brief Enumeration of supported data types. + * + * Includes complex type descriptors which in C code would use the real counterparts, + * but the independent flags contain metadata to be passed between programming language + * interfaces. + */ +typedef enum { + simsimd_datatype_unknown_k, ///< Unknown data type + simsimd_datatype_f64_k, ///< Double precision floating point + simsimd_datatype_f32_k, ///< Single precision floating point + simsimd_datatype_f16_k, ///< Half precision floating point + simsimd_datatype_bf16_k, ///< Brain floating point + simsimd_datatype_i8_k, ///< 8-bit integer + simsimd_datatype_b8_k, ///< Single-bit values packed into 8-bit words + + simsimd_datatype_f64c_k, ///< Complex double precision floating point + simsimd_datatype_f32c_k, ///< Complex single precision floating point + simsimd_datatype_f16c_k, ///< Complex half precision floating point + simsimd_datatype_bf16c_k, ///< Complex brain floating point +} simsimd_datatype_t; + +/** + * @brief Type-punned function pointer accepting two vectors and outputting their similarity/distance. + * + * @param[in] a Pointer to the first data array. + * @param[in] b Pointer to the second data array. + * @param[in] n Number of scalar words in the input arrays. + * @param[out] d Output value as a double-precision float. + * In complex dot-products @b two double-precision scalars are exported + * for the real and imaginary parts. 
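+ *
+ * A minimal usage sketch (illustrative only): the helper `simsimd_metric_punned` declared further
+ * below resolves such a pointer; the `f32` buffers and their length are placeholders.
+ *
+ * @code
+ *     simsimd_f32_t a[1536], b[1536]; // filled by the caller
+ *     simsimd_distance_t distance;
+ *     simsimd_metric_punned_t metric =
+ *         simsimd_metric_punned(simsimd_metric_cos_k, simsimd_datatype_f32_k, simsimd_cap_any_k);
+ *     if (metric) // may be NULL if no kernel matches the requested kind and datatype
+ *         metric(a, b, 1536, &distance);
+ * @endcode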
+ */ +typedef void (*simsimd_metric_punned_t)(void const* a, void const* b, simsimd_size_t n, simsimd_distance_t* d); + +#if SIMSIMD_DYNAMIC_DISPATCH +SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void); +#else +SIMSIMD_PUBLIC simsimd_capability_t simsimd_capabilities(void); +#endif + +/** + * @brief Function to determine the SIMD capabilities of the current machine at @b runtime. + * @return A bitmask of the SIMD capabilities represented as a `simsimd_capability_t` enum value. + */ +SIMSIMD_PUBLIC simsimd_capability_t simsimd_capabilities_implementation(void) { + +#if SIMSIMD_TARGET_X86 + + /// The states of 4 registers populated for a specific "cpuid" assembly call + union four_registers_t { + int array[4]; + struct separate_t { + unsigned eax, ebx, ecx, edx; + } named; + } info1, info7, info7sub1; + +#ifdef _MSC_VER + __cpuidex(info1.array, 1, 0); + __cpuidex(info7.array, 7, 0); + __cpuidex(info7sub1.array, 7, 1); +#else + __asm__ __volatile__("cpuid" + : "=a"(info1.named.eax), "=b"(info1.named.ebx), "=c"(info1.named.ecx), "=d"(info1.named.edx) + : "a"(1), "c"(0)); + __asm__ __volatile__("cpuid" + : "=a"(info7.named.eax), "=b"(info7.named.ebx), "=c"(info7.named.ecx), "=d"(info7.named.edx) + : "a"(7), "c"(0)); + __asm__ __volatile__("cpuid" + : "=a"(info7sub1.named.eax), "=b"(info7sub1.named.ebx), "=c"(info7sub1.named.ecx), + "=d"(info7sub1.named.edx) + : "a"(7), "c"(1)); +#endif + + // Check for AVX2 (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L148 + unsigned supports_avx2 = (info7.named.ebx & 0x00000020) != 0; + // Check for F16C (Function ID 1, ECX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L107 + unsigned supports_f16c = (info1.named.ecx & 0x20000000) != 0; + unsigned supports_fma = (info1.named.ecx & 0x00001000) != 0; + // Check for AVX512F (Function ID 7, EBX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L155 + unsigned supports_avx512f = (info7.named.ebx & 0x00010000) != 0; + // Check for AVX512FP16 (Function ID 7, EDX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L198C9-L198C23 + unsigned supports_avx512fp16 = (info7.named.edx & 0x00800000) != 0; + // Check for AVX512VNNI (Function ID 7, ECX register) + unsigned supports_avx512vnni = (info7.named.ecx & 0x00000800) != 0; + // Check for AVX512IFMA (Function ID 7, EBX register) + unsigned supports_avx512ifma = (info7.named.ebx & 0x00200000) != 0; + // Check for AVX512BITALG (Function ID 7, ECX register) + unsigned supports_avx512bitalg = (info7.named.ecx & 0x00001000) != 0; + // Check for AVX512VBMI2 (Function ID 7, ECX register) + unsigned supports_avx512vbmi2 = (info7.named.ecx & 0x00000040) != 0; + // Check for AVX512VPOPCNTDQ (Function ID 7, ECX register) + unsigned supports_avx512vpopcntdq = (info7.named.ecx & 0x00004000) != 0; + // Check for AVX512BF16 (Function ID 7, Sub-leaf 1, EAX register) + // https://github.com/llvm/llvm-project/blob/50598f0ff44f3a4e75706f8c53f3380fe7faa896/clang/lib/Headers/cpuid.h#L205 + unsigned supports_avx512bf16 = (info7sub1.named.eax & 0x00000020) != 0; + + // Convert specific features into CPU generations + unsigned supports_haswell = supports_avx2 && supports_f16c && supports_fma; + unsigned supports_skylake = supports_avx512f; + unsigned supports_ice = 
supports_avx512vnni && supports_avx512ifma && supports_avx512bitalg && + supports_avx512vbmi2 && supports_avx512vpopcntdq; + unsigned supports_genoa = supports_avx512bf16; + unsigned supports_sapphire = supports_avx512fp16; + + return (simsimd_capability_t)( // + (simsimd_cap_haswell_k * supports_haswell) | // + (simsimd_cap_skylake_k * supports_skylake) | // + (simsimd_cap_ice_k * supports_ice) | // + (simsimd_cap_genoa_k * supports_genoa) | // + (simsimd_cap_sapphire_k * supports_sapphire) | // + (simsimd_cap_serial_k)); + +#endif // SIMSIMD_TARGET_X86 + +#if SIMSIMD_TARGET_ARM + + // Every 64-bit Arm CPU supports NEON + unsigned supports_neon = 1; + unsigned supports_sve = 0; + unsigned supports_sve2 = 0; + +#ifdef __linux__ + unsigned long hwcap = getauxval(AT_HWCAP); + unsigned long hwcap2 = getauxval(AT_HWCAP2); + supports_sve = (hwcap & HWCAP_SVE) != 0; + supports_sve2 = (hwcap2 & HWCAP2_SVE2) != 0; +#endif + + return (simsimd_capability_t)( // + (simsimd_cap_neon_k * supports_neon) | // + (simsimd_cap_sve_k * supports_sve) | // + (simsimd_cap_sve2_k * supports_sve2) | // + (simsimd_cap_serial_k)); + +#endif // SIMSIMD_TARGET_ARM + + return simsimd_cap_serial_k; +} + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wcast-function-type" +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-function-type" + +/** + * @brief Determines the best suited metric implementation based on the given datatype, + * supported and allowed by hardware capabilities. + * + * @param kind The kind of metric to be evaluated. + * @param datatype The data type for which the metric needs to be evaluated. + * @param supported The hardware capabilities supported by the CPU. + * @param allowed The hardware capabilities allowed for use. + * @param metric_output Output variable for the selected similarity function. + * @param capability_output Output variable for the utilized hardware capabilities. 
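+ *
+ * An illustrative sketch (not part of the library): resolving a cosine kernel for `f16` vectors
+ * while restricting the search to the serial and Haswell-level backends.
+ *
+ * @code
+ *     simsimd_metric_punned_t metric = 0;
+ *     simsimd_capability_t used = simsimd_cap_serial_k;
+ *     simsimd_find_metric_punned(simsimd_metric_cos_k, simsimd_datatype_f16_k,
+ *                                simsimd_capabilities(), // what this CPU can run
+ *                                (simsimd_capability_t)(simsimd_cap_serial_k | simsimd_cap_haswell_k),
+ *                                &metric, &used);
+ * @endcode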
+ */ +SIMSIMD_PUBLIC void simsimd_find_metric_punned( // + simsimd_metric_kind_t kind, // + simsimd_datatype_t datatype, // + simsimd_capability_t supported, // + simsimd_capability_t allowed, // + simsimd_metric_punned_t* metric_output, // + simsimd_capability_t* capability_output) { + + simsimd_metric_punned_t* m = metric_output; + simsimd_capability_t* c = capability_output; + simsimd_capability_t viable = (simsimd_capability_t)(supported & allowed); + *m = (simsimd_metric_punned_t)0; + *c = (simsimd_capability_t)0; + + typedef simsimd_metric_punned_t m_t; + switch (datatype) { + + case simsimd_datatype_unknown_k: break; + + // Double-precision floating-point vectors + case simsimd_datatype_f64_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_skylake_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SKYLAKE + if (viable & simsimd_cap_skylake_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } +#endif + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f64_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f64_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + // Single-precision floating-point vectors + case simsimd_datatype_f32_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SKYLAKE + if (viable & simsimd_cap_skylake_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_skylake, *c = simsimd_cap_skylake_k; return; + case 
simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } +#endif + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f32_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f32_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + // Half-precision floating-point vectors + case simsimd_datatype_f16_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SAPPHIRE + if (viable & simsimd_cap_sapphire_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_sapphire, *c = simsimd_cap_sapphire_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_f16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_f16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_f16_serial, *c = simsimd_cap_serial_k; return; + case 
simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_f16_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + // Brain floating-point vectors + case simsimd_datatype_bf16_k: +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_GENOA + if (viable & simsimd_cap_genoa_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_genoa, *c = simsimd_cap_genoa_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_genoa, *c = simsimd_cap_genoa_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_genoa, *c = simsimd_cap_genoa_k; return; + default: break; + } +#endif + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_bf16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_bf16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_js_k: *m = (m_t)&simsimd_js_bf16_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_kl_k: *m = (m_t)&simsimd_kl_bf16_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + // Single-byte integer vectors + case simsimd_datatype_i8_k: +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_ICE + if (viable & simsimd_cap_ice_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_ice, *c = simsimd_cap_ice_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_i8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_cos_k: *m = (m_t)&simsimd_cos_i8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_l2sq_k: *m = (m_t)&simsimd_l2sq_i8_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + // Binary vectors + case simsimd_datatype_b8_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_sve, *c = 
simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_ICE + if (viable & simsimd_cap_ice_k) + switch (kind) { + case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_ice, *c = simsimd_cap_ice_k; return; + case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_ice, *c = simsimd_cap_ice_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_hamming_k: *m = (m_t)&simsimd_hamming_b8_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_jaccard_k: *m = (m_t)&simsimd_jaccard_b8_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + case simsimd_datatype_f32c_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SKYLAKE + if (viable & simsimd_cap_skylake_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f32c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f32c_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + case simsimd_datatype_f64c_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SKYLAKE + if (viable & simsimd_cap_skylake_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_skylake, *c = simsimd_cap_skylake_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_skylake, *c = simsimd_cap_skylake_k; return; + default: break; + } +#endif + + if (viable & 
simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f64c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f64c_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + + case simsimd_datatype_f16c_k: + +#if SIMSIMD_TARGET_SVE + if (viable & simsimd_cap_sve_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sve, *c = simsimd_cap_sve_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sve, *c = simsimd_cap_sve_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_NEON + if (viable & simsimd_cap_neon_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_neon, *c = simsimd_cap_neon_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_neon, *c = simsimd_cap_neon_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_SAPPHIRE + if (viable & simsimd_cap_sapphire_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_sapphire, *c = simsimd_cap_sapphire_k; return; + default: break; + } +#endif +#if SIMSIMD_TARGET_HASWELL + if (viable & simsimd_cap_haswell_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_haswell, *c = simsimd_cap_haswell_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_haswell, *c = simsimd_cap_haswell_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_f16c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_f16c_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + case simsimd_datatype_bf16c_k: + +#if SIMSIMD_TARGET_GENOA + if (viable & simsimd_cap_genoa_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_genoa, *c = simsimd_cap_genoa_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_genoa, *c = simsimd_cap_genoa_k; return; + default: break; + } +#endif + + if (viable & simsimd_cap_serial_k) + switch (kind) { + case simsimd_metric_dot_k: *m = (m_t)&simsimd_dot_bf16c_serial, *c = simsimd_cap_serial_k; return; + case simsimd_metric_vdot_k: *m = (m_t)&simsimd_vdot_bf16c_serial, *c = simsimd_cap_serial_k; return; + default: break; + } + + break; + } +} + +#pragma clang diagnostic pop +#pragma GCC diagnostic pop + +/** + * @brief Selects the most suitable metric implementation based on the given metric kind, datatype, + * and allowed capabilities. @b Don't call too often and prefer caching the `simsimd_capabilities()`. + * + * @param kind The kind of metric to be evaluated. + * @param datatype The data type for which the metric needs to be evaluated. + * @param allowed The hardware capabilities allowed for use. + * @return A function pointer to the selected metric implementation. 
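+ *
+ * A possible caching pattern (sketch only; assumes initialization races are benign or guarded by the caller):
+ *
+ * @code
+ *     static simsimd_metric_punned_t cached_l2sq = 0;
+ *     if (!cached_l2sq)
+ *         cached_l2sq = simsimd_metric_punned(simsimd_metric_l2sq_k, simsimd_datatype_f32_k, simsimd_cap_any_k);
+ * @endcode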
+ */ +SIMSIMD_PUBLIC simsimd_metric_punned_t simsimd_metric_punned( // + simsimd_metric_kind_t kind, // + simsimd_datatype_t datatype, // + simsimd_capability_t allowed) { + + simsimd_metric_punned_t result = 0; + simsimd_capability_t c = simsimd_cap_serial_k; + simsimd_capability_t supported = simsimd_capabilities(); + simsimd_find_metric_punned(kind, datatype, supported, allowed, &result, &c); + return result; +} + +#if SIMSIMD_DYNAMIC_DISPATCH + +/* Run-time feature-testing functions + * - Check if the CPU supports NEON or SVE extensions on Arm + * - Check if the CPU supports AVX2 and F16C extensions on Haswell x86 CPUs and newer + * - Check if the CPU supports AVX512F and AVX512BW extensions on Skylake x86 CPUs and newer + * - Check if the CPU supports AVX512VNNI, AVX512IFMA, AVX512BITALG, AVX512VBMI2, and AVX512VPOPCNTDQ + * extensions on Ice Lake x86 CPUs and newer + * - Check if the CPU supports AVX512FP16 extensions on Sapphire Rapids x86 CPUs and newer + * + * @return 1 if the CPU supports the SIMD instruction set, 0 otherwise. + */ +SIMSIMD_DYNAMIC int simsimd_uses_neon(void); +SIMSIMD_DYNAMIC int simsimd_uses_sve(void); +SIMSIMD_DYNAMIC int simsimd_uses_haswell(void); +SIMSIMD_DYNAMIC int simsimd_uses_skylake(void); +SIMSIMD_DYNAMIC int simsimd_uses_ice(void); +SIMSIMD_DYNAMIC int simsimd_uses_sapphire(void); +SIMSIMD_DYNAMIC int simsimd_uses_genoa(void); +SIMSIMD_DYNAMIC simsimd_capability_t simsimd_capabilities(void); + +/* Inner products + * - Dot product: the sum of the products of the corresponding elements of two vectors. + * - Complex Dot product: dot product with a conjugate first argument. + * - Complex Conjugate Dot product: dot product with a conjugate first argument. + * + * @param a The first vector. + * @param b The second vector. + * @param n The number of elements in the vectors. Even for complex variants (the number of scalars). + * @param d The output distance value. + * + * @note The dot product can be negative, to use as a distance, take `1 - a * b`. + * @note The dot product is zero if and only if the two vectors are orthogonal. + * @note Defined only for floating-point and integer data types. 
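+ *
+ * For example (sketch; `a`, `b`, and `n` are placeholders for two `f32` vectors and their length):
+ *
+ * @code
+ *     simsimd_distance_t similarity;
+ *     simsimd_dot_f32(a, b, n, &similarity);
+ *     simsimd_distance_t distance = 1 - similarity; // per the note above
+ * @endcode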
+ */ +SIMSIMD_DYNAMIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_dot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_vdot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_vdot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_vdot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d); + +/* Spatial distances + * - Cosine distance: the cosine of the angle between two vectors. + * - L2 squared distance: the squared Euclidean distance between two vectors. + * + * @param a The first vector. + * @param b The second vector. + * @param n The number of elements in the vectors. + * @param d The output distance value. + * + * @note The output distance value is non-negative. + * @note The output distance value is zero if and only if the two vectors are identical. + * @note Defined only for floating-point and integer data types. 
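+ *
+ * For example (sketch): the L2 kernels return the squared distance, so the caller can take the
+ * square root when the true Euclidean distance is needed (`a`, `b`, and `n` are placeholders):
+ *
+ * @code
+ *     simsimd_distance_t l2sq;
+ *     simsimd_l2sq_f32(a, b, n, &l2sq);
+ *     double l2 = sqrt(l2sq); // requires <math.h>
+ * @endcode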
+ */
+SIMSIMD_DYNAMIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
+                                    simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
+                                      simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_cos_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
+                                      simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
+                                       simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
+                                      simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_l2sq_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
+                                      simsimd_distance_t* d);
+
+/* Binary distances
+ * - Hamming distance: the number of positions at which the corresponding bits are different.
+ * - Jaccard distance: ratio of bit-level matching positions (intersection) to the total number of positions (union).
+ *
+ * @param a The first binary vector.
+ * @param b The second binary vector.
+ * @param n The number of 8-bit words in the vectors.
+ * @param d The output distance value.
+ *
+ * @note The output distance value is non-negative.
+ * @note The output distance value is zero if and only if the two vectors are identical.
+ * @note Defined only for binary data.
+ */
+SIMSIMD_DYNAMIC void simsimd_hamming_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
+                                        simsimd_distance_t* d);
+SIMSIMD_DYNAMIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
+                                        simsimd_distance_t* d);
+
+/* Probability distributions
+ * - Jensen-Shannon divergence: a measure of similarity between two probability distributions.
+ * - Kullback-Leibler divergence: a measure of how one probability distribution diverges from a second.
+ *
+ * @param a The first discrete probability distribution.
+ * @param b The second discrete probability distribution.
+ * @param n The number of elements in the discrete distributions.
+ * @param d The output divergence value.
+ *
+ * @note The distributions are assumed to be normalized.
+ * @note The output divergence value is non-negative.
+ * @note The output divergence value is zero if and only if the two distributions are identical.
+ * @note Defined only for floating-point data types.
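+ *
+ * For example (sketch; `p` and `q` are placeholders for two normalized `f32` histograms of `n` bins):
+ *
+ * @code
+ *     simsimd_distance_t divergence;
+ *     simsimd_js_f32(p, q, n, &divergence);
+ * @endcode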
+ */ +SIMSIMD_DYNAMIC void simsimd_kl_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_kl_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_kl_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_kl_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_js_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_js_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_js_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d); +SIMSIMD_DYNAMIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d); + +#else + +/* Compile-time feature-testing functions + * - Check if the CPU supports NEON or SVE extensions on Arm + * - Check if the CPU supports AVX2 and F16C extensions on Haswell x86 CPUs and newer + * - Check if the CPU supports AVX512F and AVX512BW extensions on Skylake x86 CPUs and newer + * - Check if the CPU supports AVX512VNNI, AVX512IFMA, AVX512BITALG, AVX512VBMI2, and AVX512VPOPCNTDQ + * extensions on Ice Lake x86 CPUs and newer + * - Check if the CPU supports AVX512BF16 extensions on Genoa x86 CPUs and newer + * - Check if the CPU supports AVX512FP16 extensions on Sapphire Rapids x86 CPUs and newer + * + * @return 1 if the CPU supports the SIMD instruction set, 0 otherwise. + */ +SIMSIMD_PUBLIC int simsimd_uses_neon(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_NEON; } +SIMSIMD_PUBLIC int simsimd_uses_sve(void) { return SIMSIMD_TARGET_ARM && SIMSIMD_TARGET_SVE; } +SIMSIMD_PUBLIC int simsimd_uses_haswell(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_HASWELL; } +SIMSIMD_PUBLIC int simsimd_uses_skylake(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SKYLAKE; } +SIMSIMD_PUBLIC int simsimd_uses_ice(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_ICE; } +SIMSIMD_PUBLIC int simsimd_uses_sapphire(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_SAPPHIRE; } +SIMSIMD_PUBLIC int simsimd_uses_genoa(void) { return SIMSIMD_TARGET_X86 && SIMSIMD_TARGET_GENOA; } +SIMSIMD_PUBLIC simsimd_capability_t simsimd_capabilities(void) { return simsimd_capabilities_implementation(); } + +/* Inner products + * - Dot product: the sum of the products of the corresponding elements of two vectors. + * - Complex Dot product: dot product with a conjugate first argument. + * - Complex Conjugate Dot product: dot product with a conjugate first argument. + * + * @param a The first vector. + * @param b The second vector. + * @param n The number of elements in the vectors. Even for complex variants (the number of scalars). + * @param d The output distance value. + * + * @note The dot product can be negative, to use as a distance, take `1 - a * b`. + * @note The dot product is zero if and only if the two vectors are orthogonal. + * @note Defined only for floating-point and integer data types. 
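+ *
+ * For complex variants the length argument still counts scalars, so `k` complex numbers are passed
+ * as `n = 2 * k`, and two scalars (real and imaginary parts) are written to the output.
+ * Illustrative sketch; `a`, `b`, and `k` are placeholders:
+ *
+ * @code
+ *     simsimd_distance_t dot_parts[2];
+ *     simsimd_dot_f32c(a, b, 2 * k, dot_parts);
+ * @endcode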
+ */ +SIMSIMD_PUBLIC void simsimd_dot_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f16_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f16_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SAPPHIRE + simsimd_dot_f16_sapphire(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_f16_haswell(a, b, n, d); +#else + simsimd_dot_f16_serial(a, b, n, d); +#endif +} + +SIMSIMD_PUBLIC void simsimd_dot_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_GENOA + simsimd_dot_bf16_genoa(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_bf16_haswell(a, b, n, d); +#else + simsimd_dot_bf16_serial(a, b, n, d); +#endif +} + +SIMSIMD_PUBLIC void simsimd_dot_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f32_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f32_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_dot_f32_skylake(a, b, n, d); +#else + simsimd_dot_f32_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f64_sve(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_dot_f64_skylake(a, b, n, d); +#else + simsimd_dot_f64_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f16c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f16c_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SAPPHIRE + simsimd_dot_f16c_sapphire(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_f16c_haswell(a, b, n, d); +#else + simsimd_dot_f16c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_GENOA + simsimd_dot_bf16c_genoa(a, b, n, d); +#else + simsimd_dot_bf16c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f32c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f32c_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_dot_f32c_skylake(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_f32c_haswell(a, b, n, d); +#else + simsimd_dot_f32c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_dot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_dot_f64c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_dot_f64c_skylake(a, b, n, d); +#else + simsimd_dot_f64c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_vdot_f16c(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_vdot_f16c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f16c_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SAPPHIRE + simsimd_dot_f16c_sapphire(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_f16c_haswell(a, b, n, d); +#else + simsimd_vdot_f16c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_vdot_bf16c(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { + 
simsimd_vdot_bf16c_serial(a, b, n, d); +} +SIMSIMD_PUBLIC void simsimd_vdot_f32c(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_vdot_f32c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_dot_f32c_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_dot_f32c_skylake(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_dot_f32c_haswell(a, b, n, d); +#else + simsimd_vdot_f32c_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_vdot_f64c(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_vdot_f64c_sve(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_vdot_f64c_skylake(a, b, n, d); +#else + simsimd_vdot_f64c_serial(a, b, n, d); +#endif +} + +/* Spatial distances + * - Cosine distance: the cosine of the angle between two vectors. + * - L2 squared distance: the squared Euclidean distance between two vectors. + * + * @param a The first vector. + * @param b The second vector. + * @param n The number of elements in the vectors. + * @param d The output distance value. + * + * @note The output distance value is non-negative. + * @note The output distance value is zero if and only if the two vectors are identical. + * @note Defined only for floating-point and integer data types. + */ +SIMSIMD_PUBLIC void simsimd_cos_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_cos_i8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_cos_i8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_cos_i8_haswell(a, b, n, d); +#else + simsimd_cos_i8_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_l2sq_i8(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_l2sq_i8_neon(a, b, n, d); +#elif SIMSIMD_TARGET_ICE + simsimd_l2sq_i8_ice(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_l2sq_i8_haswell(a, b, n, d); +#else + simsimd_l2sq_i8_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_cos_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_cos_f16_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_cos_f16_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SAPPHIRE + simsimd_cos_f16_sapphire(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_cos_f16_haswell(a, b, n, d); +#else + simsimd_cos_f16_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_cos_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_GENOA + simsimd_cos_bf16_genoa(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_cos_bf16_haswell(a, b, n, d); +#else + simsimd_cos_bf16_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_cos_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_cos_f32_sve(a, b, n, d); +#elif SIMSIMD_TARGET_NEON + simsimd_cos_f32_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_cos_f32_skylake(a, b, n, d); +#else + simsimd_cos_f32_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_cos_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_SVE + simsimd_cos_f64_sve(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_cos_f64_skylake(a, b, n, d); +#else + simsimd_cos_f64_serial(a, 
b, n, d);
+#endif
+}
+SIMSIMD_PUBLIC void simsimd_l2sq_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_SVE
+    simsimd_l2sq_f16_sve(a, b, n, d);
+#elif SIMSIMD_TARGET_NEON
+    simsimd_l2sq_f16_neon(a, b, n, d);
+#elif SIMSIMD_TARGET_SAPPHIRE
+    simsimd_l2sq_f16_sapphire(a, b, n, d);
+#elif SIMSIMD_TARGET_HASWELL
+    simsimd_l2sq_f16_haswell(a, b, n, d);
+#else
+    simsimd_l2sq_f16_serial(a, b, n, d);
+#endif
+}
+SIMSIMD_PUBLIC void simsimd_l2sq_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n,
+                                      simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_GENOA
+    simsimd_l2sq_bf16_genoa(a, b, n, d);
+#elif SIMSIMD_TARGET_HASWELL
+    simsimd_l2sq_bf16_haswell(a, b, n, d);
+#else
+    simsimd_l2sq_bf16_serial(a, b, n, d);
+#endif
+}
+SIMSIMD_PUBLIC void simsimd_l2sq_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_SVE
+    simsimd_l2sq_f32_sve(a, b, n, d);
+#elif SIMSIMD_TARGET_NEON
+    simsimd_l2sq_f32_neon(a, b, n, d);
+#elif SIMSIMD_TARGET_SKYLAKE
+    simsimd_l2sq_f32_skylake(a, b, n, d);
+#else
+    simsimd_l2sq_f32_serial(a, b, n, d);
+#endif
+}
+SIMSIMD_PUBLIC void simsimd_l2sq_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n,
+                                     simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_SVE
+    simsimd_l2sq_f64_sve(a, b, n, d);
+#elif SIMSIMD_TARGET_SKYLAKE
+    simsimd_l2sq_f64_skylake(a, b, n, d);
+#else
+    simsimd_l2sq_f64_serial(a, b, n, d);
+#endif
+}
+
+/* Binary distances
+ * - Hamming distance: the number of positions at which the corresponding bits are different.
+ * - Jaccard distance: one minus the ratio of shared set bits (intersection) to bits set in either vector (union).
+ *
+ * @param a The first binary vector.
+ * @param b The second binary vector.
+ * @param n The number of 8-bit words in the vectors.
+ * @param d The output distance value.
+ *
+ * @note The output distance value is non-negative.
+ * @note The output distance value is zero if and only if the two vectors are identical.
+ * @note Defined only for binary data.
+ */
+SIMSIMD_PUBLIC void simsimd_hamming_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
+                                       simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_SVE
+    simsimd_hamming_b8_sve(a, b, n, d);
+#elif SIMSIMD_TARGET_NEON
+    simsimd_hamming_b8_neon(a, b, n, d);
+#elif SIMSIMD_TARGET_ICE
+    simsimd_hamming_b8_ice(a, b, n, d);
+#elif SIMSIMD_TARGET_HASWELL
+    simsimd_hamming_b8_haswell(a, b, n, d);
+#else
+    simsimd_hamming_b8_serial(a, b, n, d);
+#endif
+}
+SIMSIMD_PUBLIC void simsimd_jaccard_b8(simsimd_b8_t const* a, simsimd_b8_t const* b, simsimd_size_t n,
+                                       simsimd_distance_t* d) {
+#if SIMSIMD_TARGET_SVE
+    simsimd_jaccard_b8_sve(a, b, n, d);
+#elif SIMSIMD_TARGET_NEON
+    simsimd_jaccard_b8_neon(a, b, n, d);
+#elif SIMSIMD_TARGET_ICE
+    simsimd_jaccard_b8_ice(a, b, n, d);
+#elif SIMSIMD_TARGET_HASWELL
+    simsimd_jaccard_b8_haswell(a, b, n, d);
+#else
+    simsimd_jaccard_b8_serial(a, b, n, d);
+#endif
+}
+
+/* Probability distributions
+ * - Jensen-Shannon divergence: a measure of similarity between two probability distributions.
+ * - Kullback-Leibler divergence: a measure of how one probability distribution diverges from a second.
+ *
+ * @param a The first discrete probability distribution.
+ * @param b The second discrete probability distribution.
+ * @param n The number of elements in the discrete distributions.
+ * @param d The output divergence value.
+ *
+ * @note The distributions are assumed to be normalized.
+ * @note The output divergence value is non-negative. + * @note The output divergence value is zero if and only if the two distributions are identical. + * @note Defined only for floating-point data types. + */ +SIMSIMD_PUBLIC void simsimd_kl_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_kl_f16_neon(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_kl_f16_haswell(a, b, n, d); +#else + simsimd_kl_f16_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_kl_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { + simsimd_kl_bf16_serial(a, b, n, d); +} +SIMSIMD_PUBLIC void simsimd_kl_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_kl_f32_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_kl_f32_skylake(a, b, n, d); +#else + simsimd_kl_f32_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_kl_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { + simsimd_kl_f64_serial(a, b, n, d); +} +SIMSIMD_PUBLIC void simsimd_js_f16(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_js_f16_neon(a, b, n, d); +#elif SIMSIMD_TARGET_HASWELL + simsimd_js_f16_haswell(a, b, n, d); +#else + simsimd_js_f16_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_js_bf16(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { + simsimd_js_bf16_serial(a, b, n, d); +} +SIMSIMD_PUBLIC void simsimd_js_f32(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { +#if SIMSIMD_TARGET_NEON + simsimd_js_f32_neon(a, b, n, d); +#elif SIMSIMD_TARGET_SKYLAKE + simsimd_js_f32_skylake(a, b, n, d); +#else + simsimd_js_f32_serial(a, b, n, d); +#endif +} +SIMSIMD_PUBLIC void simsimd_js_f64(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* d) { + simsimd_js_f64_serial(a, b, n, d); +} + +#endif + +#ifdef __cplusplus +} +#endif + +#endif // SIMSIMD_H diff --git a/src/include/simsimd/spatial.h b/src/include/simsimd/spatial.h new file mode 100644 index 0000000..0b51298 --- /dev/null +++ b/src/include/simsimd/spatial.h @@ -0,0 +1,1341 @@ +/** + * @file spatial.h + * @brief SIMD-accelerated Spatial Similarity Measures. + * @author Ash Vardanian + * @date March 14, 2023 + * + * Contains: + * - L2 (Euclidean) squared distance + * - Cosine (Angular) similarity + * + * For datatypes: + * - 64-bit IEEE floating point numbers + * - 32-bit IEEE floating point numbers + * - 16-bit IEEE floating point numbers + * - 16-bit brain floating point numbers + * - 8-bit signed integral numbers + * + * For hardware architectures: + * - Arm (NEON, SVE) + * - x86 (AVX2, AVX512) + * + * x86 intrinsics: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/ + * Arm intrinsics: https://developer.arm.com/architectures/instruction-sets/intrinsics/ + */ +#ifndef SIMSIMD_SPATIAL_H +#define SIMSIMD_SPATIAL_H + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// clang-format off + +/* Serial backends for all numeric types. + * By default they use 32-bit arithmetic, unless the arguments themselves contain 64-bit floats. + * For double-precision computation check out the "*_accurate" variants of those "*_serial" functions. 
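+ * For example, judging by the SIMSIMD_MAKE_COS expansions below, `simsimd_cos_f16_serial` accumulates in `f32`,
+ * while `simsimd_cos_f16_accurate` accumulates in `f64`.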
+ */ +SIMSIMD_PUBLIC void simsimd_l2sq_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f64_serial(simsimd_f64_t const* a, simsimd_f64_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f32_serial(simsimd_f32_t const* a, simsimd_f32_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f16_serial(simsimd_f16_t const* a, simsimd_f16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_bf16_serial(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_i8_serial(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); + +/* Double-precision serial backends for all numeric types. + * For single-precision computation check out the "*_serial" counterparts of those "*_accurate" functions. + */ +SIMSIMD_PUBLIC void simsimd_l2sq_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f32_accurate(simsimd_f32_t const* a, simsimd_f32_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_f16_accurate(simsimd_f16_t const* a, simsimd_f16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_bf16_accurate(simsimd_bf16_t const* a, simsimd_bf16_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_l2sq_i8_accurate(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); +SIMSIMD_PUBLIC void simsimd_cos_i8_accurate(simsimd_i8_t const* a, simsimd_i8_t const*, simsimd_size_t n, simsimd_distance_t* d); + +/* SIMD-powered backends for Arm NEON, mostly using 32-bit arithmetic over 128-bit words. + * By far the most portable backend, covering most Arm v8 devices, over a billion phones, and almost all + * server CPUs produced before 2023. 
+ */
+SIMSIMD_PUBLIC void simsimd_l2sq_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+SIMSIMD_PUBLIC void simsimd_cos_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+SIMSIMD_PUBLIC void simsimd_l2sq_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t* d);
+
+/* SIMD-powered backends for Arm SVE, mostly using 32-bit arithmetic over variable-length platform-defined word sizes.
+ * Designed for Arm Graviton 3, Microsoft Cobalt, as well as Nvidia Grace and newer Ampere Altra CPUs.
+ */
+SIMSIMD_PUBLIC void simsimd_l2sq_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_l2sq_f16_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_f16_sve(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_l2sq_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t*);
+
+/* SIMD-powered backends for AVX2 CPUs of Haswell generation and newer, using 32-bit arithmetic over 256-bit words.
+ * First demonstrated in 2011, at least one Haswell-based processor was still being sold in 2022 — the Pentium G3420.
+ * Practically all modern x86 CPUs support AVX2, FMA, and F16C, making it a perfect baseline for SIMD algorithms.
+ * On the other hand, there is no need to implement AVX2 versions of `f32` and `f64` functions, as those are
+ * properly vectorized by recent compilers.
+ */
+SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_l2sq_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t*);
+
+/* SIMD-powered backends for AVX512 CPUs of Skylake generation and newer, using 32-bit arithmetic over 512-bit words.
+ * Skylake was launched in 2015, and discontinued in 2019. Skylake had support for F, CD, VL, DQ, and BW extensions,
+ * as well as masked operations. This is enough to supersede auto-vectorization on `f32` and `f64` types.
+ */ +SIMSIMD_PUBLIC void simsimd_l2sq_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_l2sq_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, simsimd_distance_t*); + +/* SIMD-powered backends for AVX512 CPUs of Ice Lake generation and newer, using mixed arithmetic over 512-bit words. + * Ice Lake added VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, and other extensions for integral operations. + * Sapphire Rapids added tiled matrix operations, but we are most interested in the new mixed-precision FMA instructions. + */ +SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_l2sq_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*); +SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, simsimd_distance_t*); + +// clang-format on + +#define SIMSIMD_MAKE_L2SQ(name, input_type, accumulator_type, converter) \ + SIMSIMD_PUBLIC void simsimd_l2sq_##input_type##_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* result) { \ + simsimd_##accumulator_type##_t d2 = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = converter(a[i]); \ + simsimd_##accumulator_type##_t bi = converter(b[i]); \ + d2 += (ai - bi) * (ai - bi); \ + } \ + *result = d2; \ + } + +#define SIMSIMD_MAKE_COS(name, input_type, accumulator_type, converter) \ + SIMSIMD_PUBLIC void simsimd_cos_##input_type##_##name(simsimd_##input_type##_t const* a, \ + simsimd_##input_type##_t const* b, simsimd_size_t n, \ + simsimd_distance_t* result) { \ + simsimd_##accumulator_type##_t ab = 0, a2 = 0, b2 = 0; \ + for (simsimd_size_t i = 0; i != n; ++i) { \ + simsimd_##accumulator_type##_t ai = converter(a[i]); \ + simsimd_##accumulator_type##_t bi = converter(b[i]); \ + ab += ai * bi; \ + a2 += ai * ai; \ + b2 += bi * bi; \ + } \ + *result = ab != 0 ? 
(1 - ab * SIMSIMD_RSQRT(a2) * SIMSIMD_RSQRT(b2)) : 1; \ + } + +SIMSIMD_MAKE_L2SQ(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_l2sq_f64_serial +SIMSIMD_MAKE_COS(serial, f64, f64, SIMSIMD_IDENTIFY) // simsimd_cos_f64_serial + +SIMSIMD_MAKE_L2SQ(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_l2sq_f32_serial +SIMSIMD_MAKE_COS(serial, f32, f32, SIMSIMD_IDENTIFY) // simsimd_cos_f32_serial + +SIMSIMD_MAKE_L2SQ(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_l2sq_f16_serial +SIMSIMD_MAKE_COS(serial, f16, f32, SIMSIMD_UNCOMPRESS_F16) // simsimd_cos_f16_serial + +SIMSIMD_MAKE_L2SQ(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16) // simsimd_l2sq_bf16_serial +SIMSIMD_MAKE_COS(serial, bf16, f32, SIMSIMD_UNCOMPRESS_BF16) // simsimd_cos_bf16_serial + +SIMSIMD_MAKE_L2SQ(serial, i8, i32, SIMSIMD_IDENTIFY) // simsimd_l2sq_i8_serial +SIMSIMD_MAKE_COS(serial, i8, i32, SIMSIMD_IDENTIFY) // simsimd_cos_i8_serial + +SIMSIMD_MAKE_L2SQ(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_l2sq_f32_accurate +SIMSIMD_MAKE_COS(accurate, f32, f64, SIMSIMD_IDENTIFY) // simsimd_cos_f32_accurate + +SIMSIMD_MAKE_L2SQ(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // simsimd_l2sq_f16_accurate +SIMSIMD_MAKE_COS(accurate, f16, f64, SIMSIMD_UNCOMPRESS_F16) // simsimd_cos_f16_accurate + +SIMSIMD_MAKE_L2SQ(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16) // simsimd_l2sq_bf16_accurate +SIMSIMD_MAKE_COS(accurate, bf16, f64, SIMSIMD_UNCOMPRESS_BF16) // simsimd_cos_bf16_accurate + +SIMSIMD_MAKE_L2SQ(accurate, i8, i32, SIMSIMD_IDENTIFY) // simsimd_l2sq_i8_accurate +SIMSIMD_MAKE_COS(accurate, i8, i32, SIMSIMD_IDENTIFY) // simsimd_cos_i8_accurate + +#if SIMSIMD_TARGET_ARM +#if SIMSIMD_TARGET_NEON +#pragma GCC push_options +#pragma GCC target("+simd") +#pragma clang attribute push(__attribute__((target("+simd"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + float32x4_t diff_vec = vsubq_f32(a_vec, b_vec); + sum_vec = vfmaq_f32(sum_vec, diff_vec, diff_vec); + } + simsimd_f32_t sum = vaddvq_f32(sum_vec); + for (; i < n; ++i) { + simsimd_f32_t diff = a[i] - b[i]; + sum += diff * diff; + } + *result = sum; +} + +SIMSIMD_PUBLIC void simsimd_cos_f32_neon(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_vec = vdupq_n_f32(0), a2_vec = vdupq_n_f32(0), b2_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vld1q_f32(a + i); + float32x4_t b_vec = vld1q_f32(b + i); + ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); + a2_vec = vfmaq_f32(a2_vec, a_vec, a_vec); + b2_vec = vfmaq_f32(b2_vec, b_vec, b_vec); + } + simsimd_f32_t ab = vaddvq_f32(ab_vec), a2 = vaddvq_f32(a2_vec), b2 = vaddvq_f32(b2_vec); + for (; i < n; ++i) { + simsimd_f32_t ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t a2_b2_arr[2] = {a2, b2}; + vst1_f32(a2_b2_arr, vrsqrte_f32(vld1_f32(a2_b2_arr))); + *result = ab != 0 ? 
1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("+simd+fp16") +#pragma clang attribute push(__attribute__((target("+simd+fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t sum_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i)); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i)); + float32x4_t diff_vec = vsubq_f32(a_vec, b_vec); + sum_vec = vfmaq_f32(sum_vec, diff_vec, diff_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + float16x4_t f16_vec; + simsimd_f16_t f16[4]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 4; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + float32x4_t diff_vec = vsubq_f32(vcvt_f32_f16(a_padded_tail.f16_vec), vcvt_f32_f16(b_padded_tail.f16_vec)); + sum_vec = vfmaq_f32(sum_vec, diff_vec, diff_vec); + } + + simsimd_f32_t sum = vaddvq_f32(sum_vec); + *result = sum; +} + +SIMSIMD_PUBLIC void simsimd_cos_f16_neon(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_vec = vdupq_n_f32(0), a2_vec = vdupq_n_f32(0), b2_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 4 <= n; i += 4) { + float32x4_t a_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)a + i)); + float32x4_t b_vec = vcvt_f32_f16(vld1_f16((simsimd_f16_for_arm_simd_t const*)b + i)); + ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); + a2_vec = vfmaq_f32(a2_vec, a_vec, a_vec); + b2_vec = vfmaq_f32(b2_vec, b_vec, b_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + float16x4_t f16_vec; + simsimd_f16_t f16[4]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 4; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + float32x4_t a_vec = vcvt_f32_f16(a_padded_tail.f16_vec); + float32x4_t b_vec = vcvt_f32_f16(b_padded_tail.f16_vec); + ab_vec = vfmaq_f32(ab_vec, a_vec, b_vec); + a2_vec = vfmaq_f32(a2_vec, a_vec, a_vec); + b2_vec = vfmaq_f32(b2_vec, b_vec, b_vec); + } + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t ab = vaddvq_f32(ab_vec), a2 = vaddvq_f32(a2_vec), b2 = vaddvq_f32(b2_vec); + simsimd_f32_t a2_b2_arr[2] = {a2, b2}; + float32x2_t a2_b2 = vld1_f32(a2_b2_arr); + a2_b2 = vrsqrte_f32(a2_b2); + vst1_f32(a2_b2_arr, a2_b2); + *result = ab != 0 ? 
1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +#if SIMSIMD_TARGET_NEON_BF16_IMPLEMENTED +SIMSIMD_PUBLIC void simsimd_cos_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t ab_high_vec = vdupq_n_f32(0), ab_low_vec = vdupq_n_f32(0); + float32x4_t a2_high_vec = vdupq_n_f32(0), a2_low_vec = vdupq_n_f32(0); + float32x4_t b2_high_vec = vdupq_n_f32(0), b2_low_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i); + bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i); + ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_vec, b_vec); + ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_vec, b_vec); + a2_high_vec = vbfmlaltq_f32(a2_high_vec, a_vec, a_vec); + a2_low_vec = vbfmlalbq_f32(a2_low_vec, a_vec, a_vec); + b2_high_vec = vbfmlaltq_f32(b2_high_vec, b_vec, b_vec); + b2_low_vec = vbfmlalbq_f32(b2_low_vec, b_vec, b_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + bfloat16x8_t bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + ab_high_vec = vbfmlaltq_f32(ab_high_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + ab_low_vec = vbfmlalbq_f32(ab_low_vec, a_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + a2_high_vec = vbfmlaltq_f32(a2_high_vec, a_padded_tail.bf16_vec, a_padded_tail.bf16_vec); + a2_low_vec = vbfmlalbq_f32(a2_low_vec, a_padded_tail.bf16_vec, a_padded_tail.bf16_vec); + b2_high_vec = vbfmlaltq_f32(b2_high_vec, b_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + b2_low_vec = vbfmlalbq_f32(b2_low_vec, b_padded_tail.bf16_vec, b_padded_tail.bf16_vec); + } + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t ab = vaddvq_f32(vaddq_f32(ab_high_vec, ab_low_vec)), + a2 = vaddvq_f32(vaddq_f32(a2_high_vec, a2_low_vec)), + b2 = vaddvq_f32(vaddq_f32(b2_high_vec, b2_low_vec)); + simsimd_f32_t a2_b2_arr[2] = {a2, b2}; + float32x2_t a2_b2 = vld1_f32(a2_b2_arr); + a2_b2 = vrsqrte_f32(a2_b2); + vst1_f32(a2_b2_arr, a2_b2); + *result = ab != 0 ? 1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_neon(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + float32x4_t diff_high_vec = vdupq_n_f32(0), diff_low_vec = vdupq_n_f32(0); + float32x4_t sum_high_vec = vdupq_n_f32(0), sum_low_vec = vdupq_n_f32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + bfloat16x8_t a_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)a + i); + bfloat16x8_t b_vec = vld1q_bf16((simsimd_bf16_for_arm_simd_t const*)b + i); + // We can't perform subtraction in `bf16`. One option would be to upcast to `f32` + // and then subtract, converting back to `bf16` for computing the squared difference. 
+ diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_vec)), vcvt_f32_bf16(vget_high_bf16(b_vec))); + diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_vec)), vcvt_f32_bf16(vget_low_bf16(b_vec))); + sum_high_vec = vfmaq_f32(sum_high_vec, diff_high_vec, diff_high_vec); + sum_low_vec = vfmaq_f32(sum_low_vec, diff_low_vec, diff_low_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + bfloat16x8_t bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + diff_high_vec = vsubq_f32(vcvt_f32_bf16(vget_high_bf16(a_padded_tail.bf16_vec)), + vcvt_f32_bf16(vget_high_bf16(b_padded_tail.bf16_vec))); + diff_low_vec = vsubq_f32(vcvt_f32_bf16(vget_low_bf16(a_padded_tail.bf16_vec)), + vcvt_f32_bf16(vget_low_bf16(b_padded_tail.bf16_vec))); + sum_high_vec = vfmaq_f32(sum_high_vec, diff_high_vec, diff_high_vec); + sum_low_vec = vfmaq_f32(sum_low_vec, diff_low_vec, diff_low_vec); + } + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t sum = vaddvq_f32(vaddq_f32(sum_high_vec, sum_low_vec)); + *result = sum; +} +#endif +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("arch=armv8.2-a+dotprod") +#pragma clang attribute push(__attribute__((target("arch=armv8.2-a+dotprod"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + int32x4_t d2_vec = vdupq_n_s32(0); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + int8x8_t a_vec = vld1_s8(a + i); + int8x8_t b_vec = vld1_s8(b + i); + int16x8_t a_vec16 = vmovl_s8(a_vec); + int16x8_t b_vec16 = vmovl_s8(b_vec); + int16x8_t d_vec = vsubq_s16(a_vec16, b_vec16); + int32x4_t d_low = vmull_s16(vget_low_s16(d_vec), vget_low_s16(d_vec)); + int32x4_t d_high = vmull_s16(vget_high_s16(d_vec), vget_high_s16(d_vec)); + d2_vec = vaddq_s32(d2_vec, vaddq_s32(d_low, d_high)); + } + int32_t d2 = vaddvq_s32(d2_vec); + for (; i < n; ++i) { + int32_t n = a[i] - b[i]; + d2 += n * n; + } + *result = d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_i8_neon(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + int32x4_t ab_vec = vdupq_n_s32(0); + int32x4_t a2_vec = vdupq_n_s32(0); + int32x4_t b2_vec = vdupq_n_s32(0); + simsimd_size_t i = 0; + + // If the 128-bit `vdot_s32` intrinsic is unavailable, we can use the 64-bit `vdot_s32`. 
+ // for (simsimd_size_t i = 0; i != n; i += 8) { + // int16x8_t a_vec = vmovl_s8(vld1_s8(a + i)); + // int16x8_t b_vec = vmovl_s8(vld1_s8(b + i)); + // int16x8_t ab_part_vec = vmulq_s16(a_vec, b_vec); + // int16x8_t a2_part_vec = vmulq_s16(a_vec, a_vec); + // int16x8_t b2_part_vec = vmulq_s16(b_vec, b_vec); + // ab_vec = vaddq_s32(ab_vec, vaddq_s32(vmovl_s16(vget_high_s16(ab_part_vec)), // + // vmovl_s16(vget_low_s16(ab_part_vec)))); + // a2_vec = vaddq_s32(a2_vec, vaddq_s32(vmovl_s16(vget_high_s16(a2_part_vec)), // + // vmovl_s16(vget_low_s16(a2_part_vec)))); + // b2_vec = vaddq_s32(b2_vec, vaddq_s32(vmovl_s16(vget_high_s16(b2_part_vec)), // + // vmovl_s16(vget_low_s16(b2_part_vec)))); + // } + for (; i + 16 <= n; i += 16) { + int8x16_t a_vec = vld1q_s8(a + i); + int8x16_t b_vec = vld1q_s8(b + i); + ab_vec = vdotq_s32(ab_vec, a_vec, b_vec); + a2_vec = vdotq_s32(a2_vec, a_vec, a_vec); + b2_vec = vdotq_s32(b2_vec, b_vec, b_vec); + } + + int32_t ab = vaddvq_s32(ab_vec); + int32_t a2 = vaddvq_s32(a2_vec); + int32_t b2 = vaddvq_s32(b2_vec); + + // Take care of the tail: + for (; i < n; ++i) { + int32_t ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t a2_b2_arr[2] = {(simsimd_f32_t)a2, (simsimd_f32_t)b2}; + float32x2_t a2_b2 = vld1_f32(a2_b2_arr); + a2_b2 = vrsqrte_f32(a2_b2); + vst1_f32(a2_b2_arr, a2_b2); + *result = ab != 0 ? 1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_NEON + +#if SIMSIMD_TARGET_SVE +#pragma GCC push_options +#pragma GCC target("+sve") +#pragma clang attribute push(__attribute__((target("+sve"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat32_t d2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + do { + svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); + svfloat32_t a_vec = svld1_f32(pg_vec, a + i); + svfloat32_t b_vec = svld1_f32(pg_vec, b + i); + svfloat32_t a_minus_b_vec = svsub_f32_x(pg_vec, a_vec, b_vec); + d2_vec = svmla_f32_x(pg_vec, d2_vec, a_minus_b_vec, a_minus_b_vec); + i += svcntw(); + } while (i < n); + simsimd_f32_t d2 = svaddv_f32(svptrue_b32(), d2_vec); + *result = d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_f32_sve(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat32_t ab_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + svfloat32_t a2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + svfloat32_t b2_vec = svdupq_n_f32(0.f, 0.f, 0.f, 0.f); + do { + svbool_t pg_vec = svwhilelt_b32((unsigned int)i, (unsigned int)n); + svfloat32_t a_vec = svld1_f32(pg_vec, a + i); + svfloat32_t b_vec = svld1_f32(pg_vec, b + i); + ab_vec = svmla_f32_x(pg_vec, ab_vec, a_vec, b_vec); + a2_vec = svmla_f32_x(pg_vec, a2_vec, a_vec, a_vec); + b2_vec = svmla_f32_x(pg_vec, b2_vec, b_vec, b_vec); + i += svcntw(); + } while (i < n); + + simsimd_f32_t ab = svaddv_f32(svptrue_b32(), ab_vec); + simsimd_f32_t a2 = svaddv_f32(svptrue_b32(), a2_vec); + simsimd_f32_t b2 = svaddv_f32(svptrue_b32(), b2_vec); + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t a2_b2_arr[2] = {a2, b2}; + float32x2_t a2_b2 = vld1_f32(a2_b2_arr); + a2_b2 = vrsqrte_f32(a2_b2); + vst1_f32(a2_b2_arr, a2_b2); + *result = ab != 0 ? 
1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +SIMSIMD_PUBLIC void simsimd_l2sq_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat64_t d2_vec = svdupq_n_f64(0.0, 0.0); + do { + svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); + svfloat64_t a_vec = svld1_f64(pg_vec, a + i); + svfloat64_t b_vec = svld1_f64(pg_vec, b + i); + svfloat64_t a_minus_b_vec = svsub_f64_x(pg_vec, a_vec, b_vec); + d2_vec = svmla_f64_x(pg_vec, d2_vec, a_minus_b_vec, a_minus_b_vec); + i += svcntd(); + } while (i < n); + simsimd_f64_t d2 = svaddv_f64(svptrue_b32(), d2_vec); + *result = d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_f64_sve(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat64_t ab_vec = svdupq_n_f64(0.0, 0.0); + svfloat64_t a2_vec = svdupq_n_f64(0.0, 0.0); + svfloat64_t b2_vec = svdupq_n_f64(0.0, 0.0); + do { + svbool_t pg_vec = svwhilelt_b64((unsigned int)i, (unsigned int)n); + svfloat64_t a_vec = svld1_f64(pg_vec, a + i); + svfloat64_t b_vec = svld1_f64(pg_vec, b + i); + ab_vec = svmla_f64_x(pg_vec, ab_vec, a_vec, b_vec); + a2_vec = svmla_f64_x(pg_vec, a2_vec, a_vec, a_vec); + b2_vec = svmla_f64_x(pg_vec, b2_vec, b_vec, b_vec); + i += svcntd(); + } while (i < n); + + simsimd_f64_t ab = svaddv_f64(svptrue_b32(), ab_vec); + simsimd_f64_t a2 = svaddv_f64(svptrue_b32(), a2_vec); + simsimd_f64_t b2 = svaddv_f64(svptrue_b32(), b2_vec); + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f64_t a2_b2_arr[2] = {a2, b2}; + float64x2_t a2_b2 = vld1q_f64(a2_b2_arr); + a2_b2 = vrsqrteq_f64(a2_b2); + vst1q_f64(a2_b2_arr, a2_b2); + *result = ab != 0 ? 1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +#pragma clang attribute pop +#pragma GCC pop_options + +#pragma GCC push_options +#pragma GCC target("+sve+fp16") +#pragma clang attribute push(__attribute__((target("+sve+fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat16_t d2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + simsimd_f16_for_arm_simd_t const* a = (simsimd_f16_for_arm_simd_t const*)(a_enum); + simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + do { + svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); + svfloat16_t a_vec = svld1_f16(pg_vec, a + i); + svfloat16_t b_vec = svld1_f16(pg_vec, b + i); + svfloat16_t a_minus_b_vec = svsub_f16_x(pg_vec, a_vec, b_vec); + d2_vec = svmla_f16_x(pg_vec, d2_vec, a_minus_b_vec, a_minus_b_vec); + i += svcnth(); + } while (i < n); + simsimd_f16_for_arm_simd_t d2_f16 = svaddv_f16(svptrue_b16(), d2_vec); + *result = d2_f16; +} + +SIMSIMD_PUBLIC void simsimd_cos_f16_sve(simsimd_f16_t const* a_enum, simsimd_f16_t const* b_enum, simsimd_size_t n, + simsimd_distance_t* result) { + simsimd_size_t i = 0; + svfloat16_t ab_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + svfloat16_t a2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + svfloat16_t b2_vec = svdupq_n_f16(0, 0, 0, 0, 0, 0, 0, 0); + simsimd_f16_for_arm_simd_t const* a = (simsimd_f16_for_arm_simd_t const*)(a_enum); + simsimd_f16_for_arm_simd_t const* b = (simsimd_f16_for_arm_simd_t const*)(b_enum); + do { + svbool_t pg_vec = svwhilelt_b16((unsigned int)i, (unsigned int)n); + svfloat16_t a_vec = svld1_f16(pg_vec, a + i); + svfloat16_t b_vec = svld1_f16(pg_vec, b + i); 
+ ab_vec = svmla_f16_x(pg_vec, ab_vec, a_vec, b_vec); + a2_vec = svmla_f16_x(pg_vec, a2_vec, a_vec, a_vec); + b2_vec = svmla_f16_x(pg_vec, b2_vec, b_vec, b_vec); + i += svcnth(); + } while (i < n); + + simsimd_f16_for_arm_simd_t ab = svaddv_f16(svptrue_b16(), ab_vec); + simsimd_f16_for_arm_simd_t a2 = svaddv_f16(svptrue_b16(), a2_vec); + simsimd_f16_for_arm_simd_t b2 = svaddv_f16(svptrue_b16(), b2_vec); + + // Avoid `simsimd_approximate_inverse_square_root` on Arm NEON + simsimd_f32_t a2_b2_arr[2] = {a2, b2}; + float32x2_t a2_b2 = vld1_f32(a2_b2_arr); + a2_b2 = vrsqrte_f32(a2_b2); + vst1_f32(a2_b2_arr, a2_b2); + *result = ab != 0 ? 1 - ab * a2_b2_arr[0] * a2_b2_arr[1] : 1; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SVE +#endif // SIMSIMD_TARGET_ARM + +#if SIMSIMD_TARGET_X86 +#if SIMSIMD_TARGET_HASWELL +#pragma GCC push_options +#pragma GCC target("avx2", "f16c", "fma") +#pragma clang attribute push(__attribute__((target("avx2,f16c,fma"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 d2_vec = _mm256_setzero_ps(); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); + d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. + if (i < n) { + union { + __m128i f16_vec; + simsimd_f16_t f16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + __m256 a_vec = _mm256_cvtph_ps(a_padded_tail.f16_vec); + __m256 b_vec = _mm256_cvtph_ps(b_padded_tail.f16_vec); + __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); + d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); + } + + d2_vec = _mm256_add_ps(_mm256_permute2f128_ps(d2_vec, d2_vec, 1), d2_vec); + d2_vec = _mm256_hadd_ps(d2_vec, d2_vec); + d2_vec = _mm256_hadd_ps(d2_vec, d2_vec); + + simsimd_f32_t f32_result; + _mm_store_ss(&f32_result, _mm256_castps256_ps128(d2_vec)); + *result = f32_result; +} + +SIMSIMD_PUBLIC void simsimd_cos_f16_haswell(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256 ab_vec = _mm256_setzero_ps(), a2_vec = _mm256_setzero_ps(), b2_vec = _mm256_setzero_ps(); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + __m256 a_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(a + i))); + __m256 b_vec = _mm256_cvtph_ps(_mm_loadu_si128((__m128i const*)(b + i))); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); + b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); + } + + // In case the software emulation for `f16` scalars is enabled, the `simsimd_uncompress_f16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
+ if (i < n) { + union { + __m128i f16_vec; + simsimd_f16_t f16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.f16[j] = a[i], b_padded_tail.f16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.f16[j] = 0, b_padded_tail.f16[j] = 0; + __m256 a_vec = _mm256_cvtph_ps(a_padded_tail.f16_vec); + __m256 b_vec = _mm256_cvtph_ps(b_padded_tail.f16_vec); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); + b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); + } + + // Horizontal reductions: + ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + + a2_vec = _mm256_add_ps(_mm256_permute2f128_ps(a2_vec, a2_vec, 1), a2_vec); + a2_vec = _mm256_hadd_ps(a2_vec, a2_vec); + a2_vec = _mm256_hadd_ps(a2_vec, a2_vec); + + b2_vec = _mm256_add_ps(_mm256_permute2f128_ps(b2_vec, b2_vec, 1), b2_vec); + b2_vec = _mm256_hadd_ps(b2_vec, b2_vec); + b2_vec = _mm256_hadd_ps(b2_vec, b2_vec); + + simsimd_f32_t ab, a2, b2; + _mm_store_ss(&ab, _mm256_castps256_ps128(ab_vec)); + _mm_store_ss(&a2, _mm256_castps256_ps128(a2_vec)); + _mm_store_ss(&b2, _mm256_castps256_ps128(b2_vec)); + + // Replace simsimd_approximate_inverse_square_root with `rsqrtss` + __m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2)); + __m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2)); + __m128 result_vec = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Multiply the reciprocal square roots + result_vec = _mm_mul_ss(result_vec, _mm_set_ss((float)ab)); // Multiply by ab + result_vec = _mm_sub_ss(_mm_set_ss(1.0f), result_vec); // Subtract from 1 + *result = ab != 0 ? _mm_cvtss_f32(result_vec) : 1; // Extract the final result +} + +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m256 d2_vec = _mm256_setzero_ps(); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like: + // x = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(x), 16)) + __m256 a_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(a + i))), 16)); + __m256 b_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(b + i))), 16)); + __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); + d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
+ if (i < n) { + union { + __m128i bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + __m256 a_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a_padded_tail.bf16_vec), 16)); + __m256 b_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(b_padded_tail.bf16_vec), 16)); + __m256 d_vec = _mm256_sub_ps(a_vec, b_vec); + d2_vec = _mm256_fmadd_ps(d_vec, d_vec, d2_vec); + } + + d2_vec = _mm256_add_ps(_mm256_permute2f128_ps(d2_vec, d2_vec, 1), d2_vec); + d2_vec = _mm256_hadd_ps(d2_vec, d2_vec); + d2_vec = _mm256_hadd_ps(d2_vec, d2_vec); + + simsimd_f32_t f32_result; + _mm_store_ss(&f32_result, _mm256_castps256_ps128(d2_vec)); + *result = f32_result; +} + +SIMSIMD_PUBLIC void simsimd_cos_bf16_haswell(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256 ab_vec = _mm256_setzero_ps(), a2_vec = _mm256_setzero_ps(), b2_vec = _mm256_setzero_ps(); + simsimd_size_t i = 0; + for (; i + 8 <= n; i += 8) { + // Upcasting from `bf16` to `f32` is done by shifting the `bf16` values by 16 bits to the left, like: + // x = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(x), 16)) + __m256 a_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(a + i))), 16)); + __m256 b_vec = + _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((__m128i const*)(b + i))), 16)); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); + b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); + } + + // In case the software emulation for `bf16` scalars is enabled, the `simsimd_uncompress_bf16` + // function will run. It is extremely slow, so even for the tail, let's combine serial + // loads and stores with vectorized math. 
+ if (i < n) { + union { + __m128i bf16_vec; + simsimd_bf16_t bf16[8]; + } a_padded_tail, b_padded_tail; + simsimd_size_t j = 0; + for (; i < n; ++i, ++j) + a_padded_tail.bf16[j] = a[i], b_padded_tail.bf16[j] = b[i]; + for (; j < 8; ++j) + a_padded_tail.bf16[j] = 0, b_padded_tail.bf16[j] = 0; + __m256 a_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(a_padded_tail.bf16_vec), 16)); + __m256 b_vec = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(b_padded_tail.bf16_vec), 16)); + ab_vec = _mm256_fmadd_ps(a_vec, b_vec, ab_vec); + a2_vec = _mm256_fmadd_ps(a_vec, a_vec, a2_vec); + b2_vec = _mm256_fmadd_ps(b_vec, b_vec, b2_vec); + } + + // Horizontal reductions: + ab_vec = _mm256_add_ps(_mm256_permute2f128_ps(ab_vec, ab_vec, 1), ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + ab_vec = _mm256_hadd_ps(ab_vec, ab_vec); + + a2_vec = _mm256_add_ps(_mm256_permute2f128_ps(a2_vec, a2_vec, 1), a2_vec); + a2_vec = _mm256_hadd_ps(a2_vec, a2_vec); + a2_vec = _mm256_hadd_ps(a2_vec, a2_vec); + + b2_vec = _mm256_add_ps(_mm256_permute2f128_ps(b2_vec, b2_vec, 1), b2_vec); + b2_vec = _mm256_hadd_ps(b2_vec, b2_vec); + b2_vec = _mm256_hadd_ps(b2_vec, b2_vec); + + simsimd_f32_t ab, a2, b2; + _mm_store_ss(&ab, _mm256_castps256_ps128(ab_vec)); + _mm_store_ss(&a2, _mm256_castps256_ps128(a2_vec)); + _mm_store_ss(&b2, _mm256_castps256_ps128(b2_vec)); + + // Replace simsimd_approximate_inverse_square_root with `rsqrtss` + __m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2)); + __m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2)); + __m128 result_vec = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Multiply the reciprocal square roots + result_vec = _mm_mul_ss(result_vec, _mm_set_ss((float)ab)); // Multiply by ab + result_vec = _mm_sub_ss(_mm_set_ss(1.0f), result_vec); // Subtract from 1 + *result = ab != 0 ? 
_mm_cvtss_f32(result_vec) : 1; // Extract the final result +} + +SIMSIMD_PUBLIC void simsimd_l2sq_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i d2_high_vec = _mm256_setzero_si256(); + __m256i d2_low_vec = _mm256_setzero_si256(); + + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i)); + __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i)); + + // Sign extend int8 to int16 + __m256i a_low = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(a_vec)); + __m256i a_high = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1)); + __m256i b_low = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(b_vec)); + __m256i b_high = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1)); + + // Subtract and multiply + __m256i d_low = _mm256_sub_epi16(a_low, b_low); + __m256i d_high = _mm256_sub_epi16(a_high, b_high); + __m256i d2_low_part = _mm256_madd_epi16(d_low, d_low); + __m256i d2_high_part = _mm256_madd_epi16(d_high, d_high); + + // Accumulate into int32 vectors + d2_low_vec = _mm256_add_epi32(d2_low_vec, d2_low_part); + d2_high_vec = _mm256_add_epi32(d2_high_vec, d2_high_part); + } + + // Accumulate the 32-bit integers from `d2_high_vec` and `d2_low_vec` + __m256i d2_vec = _mm256_add_epi32(d2_low_vec, d2_high_vec); + __m128i d2_sum = _mm_add_epi32(_mm256_extracti128_si256(d2_vec, 0), _mm256_extracti128_si256(d2_vec, 1)); + d2_sum = _mm_hadd_epi32(d2_sum, d2_sum); + d2_sum = _mm_hadd_epi32(d2_sum, d2_sum); + int d2 = _mm_extract_epi32(d2_sum, 0); + + // Take care of the tail: + for (; i < n; ++i) { + int n = a[i] - b[i]; + d2 += n * n; + } + + *result = (simsimd_f64_t)d2; +} + +SIMSIMD_PUBLIC void simsimd_cos_i8_haswell(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + + __m256i ab_low_vec = _mm256_setzero_si256(); + __m256i ab_high_vec = _mm256_setzero_si256(); + __m256i a2_low_vec = _mm256_setzero_si256(); + __m256i a2_high_vec = _mm256_setzero_si256(); + __m256i b2_low_vec = _mm256_setzero_si256(); + __m256i b2_high_vec = _mm256_setzero_si256(); + + simsimd_size_t i = 0; + for (; i + 32 <= n; i += 32) { + __m256i a_vec = _mm256_loadu_si256((__m256i const*)(a + i)); + __m256i b_vec = _mm256_loadu_si256((__m256i const*)(b + i)); + + // Unpack `int8` to `int16` + __m256i a_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 0)); + __m256i a_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(a_vec, 1)); + __m256i b_low_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 0)); + __m256i b_high_16 = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(b_vec, 1)); + + // Multiply and accumulate as `int16`, accumulate products as `int32` + ab_low_vec = _mm256_add_epi32(ab_low_vec, _mm256_madd_epi16(a_low_16, b_low_16)); + ab_high_vec = _mm256_add_epi32(ab_high_vec, _mm256_madd_epi16(a_high_16, b_high_16)); + a2_low_vec = _mm256_add_epi32(a2_low_vec, _mm256_madd_epi16(a_low_16, a_low_16)); + a2_high_vec = _mm256_add_epi32(a2_high_vec, _mm256_madd_epi16(a_high_16, a_high_16)); + b2_low_vec = _mm256_add_epi32(b2_low_vec, _mm256_madd_epi16(b_low_16, b_low_16)); + b2_high_vec = _mm256_add_epi32(b2_high_vec, _mm256_madd_epi16(b_high_16, b_high_16)); + } + + // Horizontal sum across the 256-bit register + __m256i ab_vec = _mm256_add_epi32(ab_low_vec, ab_high_vec); + __m128i ab_sum = _mm_add_epi32(_mm256_extracti128_si256(ab_vec, 0), _mm256_extracti128_si256(ab_vec, 1)); + ab_sum = _mm_hadd_epi32(ab_sum, ab_sum); + 
ab_sum = _mm_hadd_epi32(ab_sum, ab_sum); + + __m256i a2_vec = _mm256_add_epi32(a2_low_vec, a2_high_vec); + __m128i a2_sum = _mm_add_epi32(_mm256_extracti128_si256(a2_vec, 0), _mm256_extracti128_si256(a2_vec, 1)); + a2_sum = _mm_hadd_epi32(a2_sum, a2_sum); + a2_sum = _mm_hadd_epi32(a2_sum, a2_sum); + + __m256i b2_vec = _mm256_add_epi32(b2_low_vec, b2_high_vec); + __m128i b2_sum = _mm_add_epi32(_mm256_extracti128_si256(b2_vec, 0), _mm256_extracti128_si256(b2_vec, 1)); + b2_sum = _mm_hadd_epi32(b2_sum, b2_sum); + b2_sum = _mm_hadd_epi32(b2_sum, b2_sum); + + // Further reduce to a single sum for each vector + int ab = _mm_extract_epi32(ab_sum, 0); + int a2 = _mm_extract_epi32(a2_sum, 0); + int b2 = _mm_extract_epi32(b2_sum, 0); + + // Take care of the tail: + for (; i < n; ++i) { + int ai = a[i], bi = b[i]; + ab += ai * bi, a2 += ai * ai, b2 += bi * bi; + } + + // Compute the reciprocal of the square roots + __m128 a2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)a2)); + __m128 b2_sqrt_recip = _mm_rsqrt_ss(_mm_set_ss((float)b2)); + + // Compute cosine similarity: ab / sqrt(a2 * b2) + __m128 denom = _mm_mul_ss(a2_sqrt_recip, b2_sqrt_recip); // Reciprocal of sqrt(a2 * b2) + __m128 result_vec = _mm_mul_ss(_mm_set_ss((float)ab), denom); // ab * reciprocal of sqrt(a2 * b2) + *result = ab != 0 ? 1 - _mm_cvtss_f32(result_vec) : 0; // Extract the final result +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_HASWELL + +#if SIMSIMD_TARGET_SKYLAKE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 d2_vec = _mm512_setzero(); + __m512 a_vec, b_vec; + +simsimd_l2sq_f32_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + __m512 d_vec = _mm512_sub_ps(a_vec, b_vec); + d2_vec = _mm512_fmadd_ps(d_vec, d_vec, d2_vec); + if (n) + goto simsimd_l2sq_f32_skylake_cycle; + + *result = _mm512_reduce_add_ps(d2_vec); +} + +SIMSIMD_PUBLIC void simsimd_cos_f32_skylake(simsimd_f32_t const* a, simsimd_f32_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 ab_vec = _mm512_setzero(); + __m512 a2_vec = _mm512_setzero(); + __m512 b2_vec = _mm512_setzero(); + __m512 a_vec, b_vec; + +simsimd_cos_f32_skylake_cycle: + if (n < 16) { + __mmask16 mask = (__mmask16)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_ps(mask, a); + b_vec = _mm512_maskz_loadu_ps(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_ps(a); + b_vec = _mm512_loadu_ps(b); + a += 16, b += 16, n -= 16; + } + ab_vec = _mm512_fmadd_ps(a_vec, b_vec, ab_vec); + a2_vec = _mm512_fmadd_ps(a_vec, a_vec, a2_vec); + b2_vec = _mm512_fmadd_ps(b_vec, b_vec, b2_vec); + if (n) + goto simsimd_cos_f32_skylake_cycle; + + simsimd_f32_t ab = _mm512_reduce_add_ps(ab_vec); + simsimd_f32_t a2 = _mm512_reduce_add_ps(a2_vec); + simsimd_f32_t b2 = _mm512_reduce_add_ps(b2_vec); + + // Compute the reciprocal square roots of a2 and b2 + // Mysteriously, MSVC has no `_mm_rsqrt14_ps` intrinsic, but has it's masked variants, + // so let's use `_mm_maskz_rsqrt14_ps(0xFF, ...)` instead. 
+ __m128 rsqrts = _mm_maskz_rsqrt14_ps(0xFF, _mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f)); + simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts); + simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1))); + *result = 1 - ab * rsqrt_a2 * rsqrt_b2; +} + +SIMSIMD_PUBLIC void simsimd_l2sq_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512d d2_vec = _mm512_setzero_pd(); + __m512d a_vec, b_vec; + +simsimd_l2sq_f64_skylake_cycle: + if (n < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + a += 8, b += 8, n -= 8; + } + __m512d d_vec = _mm512_sub_pd(a_vec, b_vec); + d2_vec = _mm512_fmadd_pd(d_vec, d_vec, d2_vec); + if (n) + goto simsimd_l2sq_f64_skylake_cycle; + + *result = _mm512_reduce_add_pd(d2_vec); +} + +SIMSIMD_PUBLIC void simsimd_cos_f64_skylake(simsimd_f64_t const* a, simsimd_f64_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512d ab_vec = _mm512_setzero_pd(); + __m512d a2_vec = _mm512_setzero_pd(); + __m512d b2_vec = _mm512_setzero_pd(); + __m512d a_vec, b_vec; + +simsimd_cos_f64_skylake_cycle: + if (n < 8) { + __mmask8 mask = (__mmask8)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_pd(mask, a); + b_vec = _mm512_maskz_loadu_pd(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_pd(a); + b_vec = _mm512_loadu_pd(b); + a += 8, b += 8, n -= 8; + } + ab_vec = _mm512_fmadd_pd(a_vec, b_vec, ab_vec); + a2_vec = _mm512_fmadd_pd(a_vec, a_vec, a2_vec); + b2_vec = _mm512_fmadd_pd(b_vec, b_vec, b2_vec); + if (n) + goto simsimd_cos_f64_skylake_cycle; + + simsimd_f32_t ab = (simsimd_f32_t)_mm512_reduce_add_pd(ab_vec); + simsimd_f32_t a2 = (simsimd_f32_t)_mm512_reduce_add_pd(a2_vec); + simsimd_f32_t b2 = (simsimd_f32_t)_mm512_reduce_add_pd(b2_vec); + + // Compute the reciprocal square roots of a2 and b2 + // Mysteriously, MSVC has no `_mm_rsqrt14_ps` intrinsic, but has it's masked variants, + // so let's use `_mm_maskz_rsqrt14_ps(0xFF, ...)` instead. 
+ __m128 rsqrts = _mm_maskz_rsqrt14_ps(0xFF, _mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f)); + simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts); + simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1))); + *result = 1 - ab * rsqrt_a2 * rsqrt_b2; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SKYLAKE + +#if SIMSIMD_TARGET_GENOA +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512bf16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512bf16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 d2_top_vec = _mm512_setzero_ps(), d2_bot_vec = _mm512_setzero_ps(); + __m512 d_top_vec = _mm512_setzero_ps(), d_bot_vec = _mm512_setzero_ps(); + __m512 a_f32_top_vec, a_f32_bot_vec, b_f32_top_vec, b_f32_bot_vec; + __m512i a_i16_vec, b_i16_vec; + +simsimd_l2sq_bf16_genoa_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + // Let's perform the subtraction with single-precision, while the dot-product with half-precision. + // For that we need to perform a couple of casts - each is a bitshift. To convert `bf16` to `f32`, + // expand it to 32-bit integers, then shift the bits by 16 to the left. Then subtract as floats, + // and shift back. During expansion, we will double the space, and should use separate registers + // for top and bottom halves. 
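    // On a single value the same widening trick looks like this (compare with
    // `simsimd_uncompress_bf16` in `types.h`, and note `bf16_bits` is just a hypothetical
    // variable holding the raw 16 bits): a `bf16` is the upper half of an IEEE-754 `float`,
    // so shifting it into the high 16 bits recovers the value exactly:
    //
    //     union { unsigned int i; float f; } conv;
    //     conv.i = (unsigned int)bf16_bits << 16; // low mantissa bits become zero
    //     float value = conv.f;
    //
    // The code below does the same for 16 lanes at a time with `_mm512_slli_epi32`.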
+ a_f32_bot_vec = + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(a_i16_vec)), 16)); + b_f32_bot_vec = + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_castsi512_si256(b_i16_vec)), 16)); + + // Some compilers don't have `_mm512_extracti32x8_epi32`, so we need to use `_mm512_extracti64x4_epi64` + a_f32_top_vec = + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(a_i16_vec, 1)), 16)); + b_f32_top_vec = + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm512_extracti64x4_epi64(b_i16_vec, 1)), 16)); + + // Subtract and cast back + d_top_vec = _mm512_sub_ps(a_f32_top_vec, b_f32_top_vec); + d_bot_vec = _mm512_sub_ps(a_f32_bot_vec, b_f32_bot_vec); + d_top_vec = _mm512_castsi512_ps(_mm512_srli_epi32(_mm512_castps_si512(d_top_vec), 16)); + d_bot_vec = _mm512_castsi512_ps(_mm512_srli_epi32(_mm512_castps_si512(d_bot_vec), 16)); + + // Square and accumulate + d2_top_vec = _mm512_dpbf16_ps(d2_top_vec, (__m512bh)(d_top_vec), (__m512bh)(d_top_vec)); + d2_bot_vec = _mm512_dpbf16_ps(d2_bot_vec, (__m512bh)(d_bot_vec), (__m512bh)(d_bot_vec)); + if (n) + goto simsimd_l2sq_bf16_genoa_cycle; + + *result = _mm512_reduce_add_ps(d2_top_vec) + _mm512_reduce_add_ps(d2_bot_vec); +} + +SIMSIMD_PUBLIC void simsimd_cos_bf16_genoa(simsimd_bf16_t const* a, simsimd_bf16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512 ab_vec = _mm512_setzero_ps(); + __m512 a2_vec = _mm512_setzero_ps(); + __m512 b2_vec = _mm512_setzero_ps(); + __m512i a_i16_vec, b_i16_vec; + +simsimd_cos_bf16_genoa_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_vec = _mm512_dpbf16_ps(ab_vec, (__m512bh)(a_i16_vec), (__m512bh)(b_i16_vec)); + a2_vec = _mm512_dpbf16_ps(a2_vec, (__m512bh)(a_i16_vec), (__m512bh)(a_i16_vec)); + b2_vec = _mm512_dpbf16_ps(b2_vec, (__m512bh)(b_i16_vec), (__m512bh)(b_i16_vec)); + if (n) + goto simsimd_cos_bf16_genoa_cycle; + + simsimd_f32_t ab = _mm512_reduce_add_ps(ab_vec); + simsimd_f32_t a2 = _mm512_reduce_add_ps(a2_vec); + simsimd_f32_t b2 = _mm512_reduce_add_ps(b2_vec); + + // Compute the reciprocal square roots of a2 and b2 + __m128 rsqrts = _mm_rsqrt14_ps(_mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f)); + simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts); + simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1))); + *result = ab != 0 ? 
1 - ab * rsqrt_a2 * rsqrt_b2 : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_GENOA + +#if SIMSIMD_TARGET_SAPPHIRE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512fp16") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512fp16"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512h d2_vec = _mm512_setzero_ph(); + __m512i a_i16_vec, b_i16_vec; + +simsimd_l2sq_f16_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + __m512h d_vec = _mm512_sub_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec)); + d2_vec = _mm512_fmadd_ph(d_vec, d_vec, d2_vec); + if (n) + goto simsimd_l2sq_f16_sapphire_cycle; + + *result = _mm512_reduce_add_ph(d2_vec); +} + +SIMSIMD_PUBLIC void simsimd_cos_f16_sapphire(simsimd_f16_t const* a, simsimd_f16_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512h ab_vec = _mm512_setzero_ph(); + __m512h a2_vec = _mm512_setzero_ph(); + __m512h b2_vec = _mm512_setzero_ph(); + __m512i a_i16_vec, b_i16_vec; + +simsimd_cos_f16_sapphire_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_i16_vec = _mm512_maskz_loadu_epi16(mask, a); + b_i16_vec = _mm512_maskz_loadu_epi16(mask, b); + n = 0; + } else { + a_i16_vec = _mm512_loadu_epi16(a); + b_i16_vec = _mm512_loadu_epi16(b); + a += 32, b += 32, n -= 32; + } + ab_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(b_i16_vec), ab_vec); + a2_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(a_i16_vec), _mm512_castsi512_ph(a_i16_vec), a2_vec); + b2_vec = _mm512_fmadd_ph(_mm512_castsi512_ph(b_i16_vec), _mm512_castsi512_ph(b_i16_vec), b2_vec); + if (n) + goto simsimd_cos_f16_sapphire_cycle; + + simsimd_f32_t ab = _mm512_reduce_add_ph(ab_vec); + simsimd_f32_t a2 = _mm512_reduce_add_ph(a2_vec); + simsimd_f32_t b2 = _mm512_reduce_add_ph(b2_vec); + + // Compute the reciprocal square roots of a2 and b2 + __m128 rsqrts = _mm_rsqrt14_ps(_mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f)); + simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts); + simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1))); + *result = ab != 0 ? 
1 - ab * rsqrt_a2 * rsqrt_b2 : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_SAPPHIRE + +#if SIMSIMD_TARGET_ICE +#pragma GCC push_options +#pragma GCC target("avx512f", "avx512vl", "bmi2", "avx512bw", "avx512vnni") +#pragma clang attribute push(__attribute__((target("avx512f,avx512vl,bmi2,avx512bw,avx512vnni"))), apply_to = function) + +SIMSIMD_PUBLIC void simsimd_l2sq_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512i d2_i32s_vec = _mm512_setzero_si512(); + __m512i a_vec, b_vec, d_i16s_vec; + +simsimd_l2sq_i8_ice_cycle: + if (n < 32) { + __mmask32 mask = (__mmask32)_bzhi_u32(0xFFFFFFFF, n); + a_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, a)); + b_vec = _mm512_cvtepi8_epi16(_mm256_maskz_loadu_epi8(mask, b)); + n = 0; + } else { + a_vec = _mm512_cvtepi8_epi16(_mm256_loadu_epi8(a)); + b_vec = _mm512_cvtepi8_epi16(_mm256_loadu_epi8(b)); + a += 32, b += 32, n -= 32; + } + d_i16s_vec = _mm512_sub_epi16(a_vec, b_vec); + d2_i32s_vec = _mm512_dpwssd_epi32(d2_i32s_vec, d_i16s_vec, d_i16s_vec); + if (n) + goto simsimd_l2sq_i8_ice_cycle; + + *result = _mm512_reduce_add_epi32(d2_i32s_vec); +} + +SIMSIMD_PUBLIC void simsimd_cos_i8_ice(simsimd_i8_t const* a, simsimd_i8_t const* b, simsimd_size_t n, + simsimd_distance_t* result) { + __m512i ab_low_i32s_vec = _mm512_setzero_si512(); + __m512i ab_high_i32s_vec = _mm512_setzero_si512(); + __m512i a2_i32s_vec = _mm512_setzero_si512(); + __m512i b2_i32s_vec = _mm512_setzero_si512(); + __m512i a_vec, b_vec; + __m512i a_abs_vec, b_abs_vec; + +simsimd_cos_i8_ice_cycle: + if (n < 64) { + __mmask64 mask = (__mmask64)_bzhi_u64(0xFFFFFFFFFFFFFFFF, n); + a_vec = _mm512_maskz_loadu_epi8(mask, a); + b_vec = _mm512_maskz_loadu_epi8(mask, b); + n = 0; + } else { + a_vec = _mm512_loadu_epi8(a); + b_vec = _mm512_loadu_epi8(b); + a += 64, b += 64, n -= 64; + } + + // We can't directly use the `_mm512_dpbusd_epi32` intrinsic everywhere, + // as it's asymmetric with respect to the sign of the input arguments: + // Signed(ZeroExtend16(a.byte[4*j]) * SignExtend16(b.byte[4*j])) + // Luckily to compute the squares, we just drop the sign bit of the second argument. + a_abs_vec = _mm512_abs_epi8(a_vec); + b_abs_vec = _mm512_abs_epi8(b_vec); + a2_i32s_vec = _mm512_dpbusds_epi32(a2_i32s_vec, a_abs_vec, a_abs_vec); + b2_i32s_vec = _mm512_dpbusds_epi32(b2_i32s_vec, b_abs_vec, b_abs_vec); + + // The same trick won't work for the primary dot-product, as the signs vector + // components may differ significantly. So we have to use two `_mm512_dpwssd_epi32` + // intrinsics instead, upcasting four chunks to 16-bit integers beforehand! + ab_low_i32s_vec = _mm512_dpwssds_epi32( // + ab_low_i32s_vec, // + _mm512_cvtepi8_epi16(_mm512_castsi512_si256(a_vec)), // + _mm512_cvtepi8_epi16(_mm512_castsi512_si256(b_vec))); + ab_high_i32s_vec = _mm512_dpwssds_epi32( // + ab_high_i32s_vec, // + _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(a_vec, 1)), // + _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(b_vec, 1))); + if (n) + goto simsimd_cos_i8_ice_cycle; + + int ab = _mm512_reduce_add_epi32(_mm512_add_epi32(ab_low_i32s_vec, ab_high_i32s_vec)); + int a2 = _mm512_reduce_add_epi32(a2_i32s_vec); + int b2 = _mm512_reduce_add_epi32(b2_i32s_vec); + + // Compute the reciprocal square roots of a2 and b2 + // Mysteriously, MSVC has no `_mm_rsqrt14_ps` intrinsic, but has it's masked variants, + // so let's use `_mm_maskz_rsqrt14_ps(0xFF, ...)` instead. 
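    // The earlier `_mm512_abs_epi8` step is safe because squaring discards the sign:
    // for any signed byte `x`, `x * x == abs(x) * abs(x)` (e.g. (-3) * (-3) == 3 * 3 == 9).
    // A hypothetical scalar reference for the three accumulators reduced above:
    //
    //     int ab_ref = 0, a2_ref = 0, b2_ref = 0;
    //     for (simsimd_size_t k = 0; k != n; ++k)
    //         ab_ref += a[k] * b[k], a2_ref += a[k] * a[k], b2_ref += b[k] * b[k];
    //
    // Only the mixed-sign `ab` term needs the 16-bit widening; the squared norms can go
    // through the unsigned-by-signed `dpbusds` instruction directly.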
+ __m128 rsqrts = _mm_maskz_rsqrt14_ps(0xFF, _mm_set_ps(0.f, 0.f, a2 + 1.e-9f, b2 + 1.e-9f)); + simsimd_f32_t rsqrt_a2 = _mm_cvtss_f32(rsqrts); + simsimd_f32_t rsqrt_b2 = _mm_cvtss_f32(_mm_shuffle_ps(rsqrts, rsqrts, _MM_SHUFFLE(0, 0, 0, 1))); + *result = ab != 0 ? 1 - ab * rsqrt_a2 * rsqrt_b2 : 0; +} + +#pragma clang attribute pop +#pragma GCC pop_options +#endif // SIMSIMD_TARGET_ICE +#endif // SIMSIMD_TARGET_X86 + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/src/include/simsimd/types.h b/src/include/simsimd/types.h new file mode 100644 index 0000000..1ae2663 --- /dev/null +++ b/src/include/simsimd/types.h @@ -0,0 +1,423 @@ +/** + * @file types.h + * @brief Shared definitions for the SimSIMD library. + * @author Ash Vardanian + * @date October 2, 2023 + * + * Defines: + * - Sized aliases for numeric types, like: `simsimd_i32_t` and `simsimd_f64_t`. + * - Macros for compiler/hardware checks, like: `SIMSIMD_TARGET_NEON` + */ +#ifndef SIMSIMD_TYPES_H +#define SIMSIMD_TYPES_H + +/* Annotation for the public API symbols: + * + * - `SIMSIMD_PUBLIC` is used for functions that are part of the public API. + * - `SIMSIMD_INTERNAL` is used for internal helper functions with unstable APIs. + * - `SIMSIMD_DYNAMIC` is used for functions that are part of the public API, but are dispatched at runtime. + */ +#if defined(_WIN32) || defined(__CYGWIN__) +#define SIMSIMD_DYNAMIC __declspec(dllexport) +#define SIMSIMD_PUBLIC inline static +#define SIMSIMD_INTERNAL inline static +#elif defined(__GNUC__) || defined(__clang__) +#define SIMSIMD_DYNAMIC __attribute__((visibility("default"))) +#define SIMSIMD_PUBLIC __attribute__((unused)) inline static +#define SIMSIMD_INTERNAL __attribute__((always_inline)) inline static +#else +#define SIMSIMD_DYNAMIC inline static +#define SIMSIMD_PUBLIC inline static +#define SIMSIMD_INTERNAL inline static +#endif + +// Compiling for Arm: SIMSIMD_TARGET_ARM +#if !defined(SIMSIMD_TARGET_ARM) +#if defined(__aarch64__) || defined(_M_ARM64) +#define SIMSIMD_TARGET_ARM 1 +#else +#define SIMSIMD_TARGET_ARM 0 +#endif // defined(__aarch64__) || defined(_M_ARM64) +#endif // !defined(SIMSIMD_TARGET_ARM) + +// Compiling for x86: SIMSIMD_TARGET_X86 +#if !defined(SIMSIMD_TARGET_X86) +#if defined(__x86_64__) || defined(_M_X64) +#define SIMSIMD_TARGET_X86 1 +#else +#define SIMSIMD_TARGET_X86 0 +#endif // defined(__x86_64__) || defined(_M_X64) +#endif // !defined(SIMSIMD_TARGET_X86) + +// Compiling for Arm: SIMSIMD_TARGET_NEON +#if !defined(SIMSIMD_TARGET_NEON) || (SIMSIMD_TARGET_NEON && !SIMSIMD_TARGET_ARM) +#if defined(__ARM_NEON) +#define SIMSIMD_TARGET_NEON SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_NEON +#define SIMSIMD_TARGET_NEON 0 +#endif // defined(__ARM_NEON) +#endif // !defined(SIMSIMD_TARGET_NEON) + +// Compiling for Arm: SIMSIMD_TARGET_SVE +#if !defined(SIMSIMD_TARGET_SVE) || (SIMSIMD_TARGET_SVE && !SIMSIMD_TARGET_ARM) +#if defined(__ARM_FEATURE_SVE) +#define SIMSIMD_TARGET_SVE SIMSIMD_TARGET_ARM +#else +#undef SIMSIMD_TARGET_SVE +#define SIMSIMD_TARGET_SVE 0 +#endif // defined(__ARM_FEATURE_SVE) +#endif // !defined(SIMSIMD_TARGET_SVE) + +// Compiling for x86: SIMSIMD_TARGET_HASWELL +// +// Starting with Ivy Bridge, Intel supports the `F16C` extensions for fast half-precision +// to single-precision floating-point conversions. On AMD those instructions +// are supported on all CPUs starting with Jaguar 2009. +// Starting with Sandy Bridge, Intel adds basic AVX support in their CPUs and in 2013 +// extends it with AVX2 in the Haswell generation. 
Moreover, Haswell adds FMA support. +#if !defined(SIMSIMD_TARGET_HASWELL) || (SIMSIMD_TARGET_HASWELL && !SIMSIMD_TARGET_X86) +#if defined(__AVX2__) && defined(__FMA__) && defined(__F16C__) +#define SIMSIMD_TARGET_HASWELL 1 +#else +#undef SIMSIMD_TARGET_HASWELL +#define SIMSIMD_TARGET_HASWELL 0 +#endif // defined(__AVX2__) +#endif // !defined(SIMSIMD_TARGET_HASWELL) + +// Compiling for x86: SIMSIMD_TARGET_SKYLAKE, SIMSIMD_TARGET_ICE, SIMSIMD_TARGET_SAPPHIRE +// +// It's important to provide fine-grained controls over AVX512 families, as they are very fragmented: +// - Intel Skylake servers: F, CD, VL, DQ, BW +// - Intel Cascade Lake workstations: F, CD, VL, DQ, BW, VNNI +// > In other words, it extends Skylake with VNNI support +// - Intel Sunny Cove (Ice Lake) servers: +// F, CD, VL, DQ, BW, VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ +// - AMD Zen4 (Genoa): +// F, CD, VL, DQ, BW, VNNI, VPOPCNTDQ, IFMA, VBMI, VAES, GFNI, VBMI2, BITALG, VPCLMULQDQ, BF16 +// > In other words, it extends Sunny Cove with BF16 support +// - Golden Cove (Sapphire Rapids): extends Zen4 and Sunny Cove with FP16 support +// +// Intel Palm Cove was an irrelevant intermediate release extending Skylake with IFMA and VBMI. +// Intel Willow Cove was an irrelevant intermediate release extending Sunny Cove with VP2INTERSECT, +// that aren't supported by any other CPU built to date... and those are only available in Tiger Lake laptops. +// Intel Cooper Lake was the only intermediary platform, that supported BF16, but not FP16. +// It's mostly used in 4-socket and 8-socket high-memory configurations. +// +// In practical terms, it makes sense to differentiate only 3 AVX512 generations: +// 1. Skylake (pre 2019): supports single-precision dot-products. +// 2. Ice Lake (2019-2021): advanced integer algorithms. +// 3. Sapphire Rapids (2023+): advanced mixed-precision float processing. 
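// As a usage sketch (hypothetical caller code, not part of this header), the macros below
// are meant to guard both the kernel definitions and any compile-time dispatch on top:
//
//     #if SIMSIMD_TARGET_SKYLAKE
//         simsimd_cos_f32_skylake(a, b, n, &distance);
//     #else
//         /* ... fall back to a portable serial kernel ... */
//     #endif
//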
+// +// To list all available macros for x86, take a recent compiler, like GCC 12 and run: +// gcc-12 -march=sapphirerapids -dM -E - < /dev/null | egrep "SSE|AVX" | sort +// On Arm machines you may want to check for other flags: +// gcc-12 -march=native -dM -E - < /dev/null | egrep "NEON|SVE|FP16|FMA" | sort +#if !defined(SIMSIMD_TARGET_SKYLAKE) || (SIMSIMD_TARGET_SKYLAKE && !SIMSIMD_TARGET_X86) +#if defined(__AVX512F__) && defined(__AVX512CD__) && defined(__AVX512VL__) && defined(__AVX512DQ__) && \ + defined(__AVX512BW__) +#define SIMSIMD_TARGET_SKYLAKE 1 +#else +#undef SIMSIMD_TARGET_SKYLAKE +#define SIMSIMD_TARGET_SKYLAKE 0 +#endif +#endif // !defined(SIMSIMD_TARGET_SKYLAKE) +#if !defined(SIMSIMD_TARGET_ICE) || (SIMSIMD_TARGET_ICE && !SIMSIMD_TARGET_X86) +#if defined(__AVX512VNNI__) && defined(__AVX512IFMA__) && defined(__AVX512BITALG__) && defined(__AVX512VBMI2__) && \ + defined(__AVX512VPOPCNTDQ__) +#define SIMSIMD_TARGET_ICE 1 +#else +#undef SIMSIMD_TARGET_ICE +#define SIMSIMD_TARGET_ICE 0 +#endif +#endif // !defined(SIMSIMD_TARGET_ICE) +#if !defined(SIMSIMD_TARGET_GENOA) || (SIMSIMD_TARGET_GENOA && !SIMSIMD_TARGET_X86) +#if defined(__AVX512BF16__) +#define SIMSIMD_TARGET_GENOA 1 +#else +#undef SIMSIMD_TARGET_GENOA +#define SIMSIMD_TARGET_GENOA 0 +#endif +#endif // !defined(SIMSIMD_TARGET_GENOA) +#if !defined(SIMSIMD_TARGET_SAPPHIRE) || (SIMSIMD_TARGET_SAPPHIRE && !SIMSIMD_TARGET_X86) +#if defined(__AVX512FP16__) +#define SIMSIMD_TARGET_SAPPHIRE 1 +#else +#undef SIMSIMD_TARGET_SAPPHIRE +#define SIMSIMD_TARGET_SAPPHIRE 0 +#endif +#endif // !defined(SIMSIMD_TARGET_SAPPHIRE) + +#ifdef _MSC_VER +#include +#else + +#if SIMSIMD_TARGET_NEON +#include +#endif + +#if SIMSIMD_TARGET_SVE +#include +#endif + +#if SIMSIMD_TARGET_HASWELL || SIMSIMD_TARGET_SKYLAKE +#include +#endif + +#endif + +#ifndef SIMSIMD_RSQRT +#include +#define SIMSIMD_RSQRT(x) (1 / sqrtf(x)) +#endif + +#ifndef SIMSIMD_LOG +#include +#define SIMSIMD_LOG(x) (logf(x)) +#endif + +#ifndef SIMSIMD_F32_DIVISION_EPSILON +#define SIMSIMD_F32_DIVISION_EPSILON (1e-7) +#endif + +#ifndef SIMSIMD_F16_DIVISION_EPSILON +#define SIMSIMD_F16_DIVISION_EPSILON (1e-3) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef int simsimd_i32_t; +typedef float simsimd_f32_t; +typedef double simsimd_f64_t; +typedef signed char simsimd_i8_t; +typedef unsigned char simsimd_b8_t; +typedef long long simsimd_i64_t; +typedef unsigned long long simsimd_u64_t; + +typedef simsimd_u64_t simsimd_size_t; +typedef simsimd_f64_t simsimd_distance_t; + +#if !defined(SIMSIMD_NATIVE_F16) || SIMSIMD_NATIVE_F16 +/** + * @brief Half-precision floating-point type. + * + * - GCC or Clang on 64-bit Arm: `__fp16`, may require `-mfp16-format` option. + * - GCC or Clang on 64-bit x86: `_Float16`. + * - Default: `unsigned short`. 
+ */ +#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ + (defined(__ARM_FP16_FORMAT_IEEE)) +#if !defined(SIMSIMD_NATIVE_F16) +#define SIMSIMD_NATIVE_F16 1 +#endif +typedef __fp16 simsimd_f16_t; +#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ + (defined(__SSE2__) || defined(__AVX512F__))) +typedef _Float16 simsimd_f16_t; +#if !defined(SIMSIMD_NATIVE_F16) +#define SIMSIMD_NATIVE_F16 1 +#endif +#else // Unknown compiler or architecture +#define SIMSIMD_NATIVE_F16 0 +#endif // Unknown compiler or architecture +#endif // !SIMSIMD_NATIVE_F16 + +#if !SIMSIMD_NATIVE_F16 +typedef unsigned short simsimd_f16_t; +#endif + +#if !defined(SIMSIMD_NATIVE_BF16) || SIMSIMD_NATIVE_BF16 +/** + * @brief Half-precision brain-float type. + * + * - GCC or Clang on 64-bit Arm: `__bf16` + * - GCC or Clang on 64-bit x86: `_BFloat16`. + * - Default: `unsigned short`. + * + * @warning Apple Clang has hard time with bf16. + * https://developer.apple.com/documentation/xcode/writing-arm64-code-for-apple-platforms + * https://forums.developer.apple.com/forums/thread/726201 + */ +#if (defined(__GNUC__) || defined(__clang__)) && (defined(__ARM_ARCH) || defined(__aarch64__)) && \ + (defined(__ARM_BF16_FORMAT_ALTERNATIVE)) +#if !defined(SIMSIMD_NATIVE_BF16) +#define SIMSIMD_NATIVE_BF16 1 +#endif +typedef __fp16 simsimd_bf16_t; +#elif ((defined(__GNUC__) || defined(__clang__)) && (defined(__x86_64__) || defined(__i386__)) && \ + (defined(__SSE2__) || defined(__AVX512F__))) +typedef _Float16 simsimd_bf16_t; +#if !defined(SIMSIMD_NATIVE_BF16) +#define SIMSIMD_NATIVE_BF16 1 +#endif +#else // Unknown compiler or architecture +#define SIMSIMD_NATIVE_BF16 0 +#endif // Unknown compiler or architecture +#endif // !SIMSIMD_NATIVE_BF16 + +#if !SIMSIMD_NATIVE_BF16 +typedef unsigned short simsimd_bf16_t; +#endif + +/** + * @brief Alias for the half-precision floating-point type on Arm. + * + * Clang and GCC bring the `float16_t` symbol when you compile for Aarch64. + * MSVC lacks it, and it's `vld1_f16`-like intrinsics are in reality macros, + * that cast to 16-bit integers internally, instead of using floats. + * Some of those are defined as aliases, so we use `#define` preprocessor + * directives instead of `typedef` to avoid errors. + */ +#if SIMSIMD_TARGET_ARM +#if defined(_MSC_VER) +#define simsimd_f16_for_arm_simd_t simsimd_f16_t +#define simsimd_bf16_for_arm_simd_t simsimd_bf16_t +#else +#define simsimd_f16_for_arm_simd_t float16_t +#define simsimd_bf16_for_arm_simd_t bfloat16_t +#endif +#endif + +#define SIMSIMD_IDENTIFY(x) (x) + +/** + * @brief Returns the value of the half-precision floating-point number, + * potentially decompressed into single-precision. + */ +#ifndef SIMSIMD_UNCOMPRESS_F16 +#if SIMSIMD_NATIVE_F16 +#define SIMSIMD_UNCOMPRESS_F16(x) (SIMSIMD_IDENTIFY(x)) +#else +#define SIMSIMD_UNCOMPRESS_F16(x) (simsimd_uncompress_f16(x)) +#endif +#endif + +/** + * @brief Returns the value of the half-precision brain floating-point number, + * potentially decompressed into single-precision. + */ +#ifndef SIMSIMD_UNCOMPRESS_BF16 +#if SIMSIMD_NATIVE_BF16 +#define SIMSIMD_UNCOMPRESS_BF16(x) (SIMSIMD_IDENTIFY(x)) +#else +#define SIMSIMD_UNCOMPRESS_BF16(x) (simsimd_uncompress_bf16(x)) +#endif +#endif + +typedef union { + unsigned i; + float f; +} simsimd_f32i32_t; + +/** + * @brief Computes `1/sqrt(x)` using the trick from Quake 3, replacing + * magic numbers with the ones suggested by Jan Kadlec. 
+ */ +SIMSIMD_PUBLIC simsimd_f32_t simsimd_approximate_inverse_square_root(simsimd_f32_t number) { + simsimd_f32i32_t conv; + conv.f = number; + conv.i = 0x5F1FFFF9 - (conv.i >> 1); + conv.f *= 0.703952253f * (2.38924456f - number * conv.f * conv.f); + return conv.f; +} + +/** + * @brief Computes `log(x)` using the Mercator series. + * The series converges to the natural logarithm for args between -1 and 1. + * Published in 1668 in "Logarithmotechnia". + */ +SIMSIMD_PUBLIC simsimd_f32_t simsimd_approximate_log(simsimd_f32_t number) { + simsimd_f32_t x = number - 1; + simsimd_f32_t x2 = x * x; + simsimd_f32_t x3 = x * x * x; + return x - x2 / 2 + x3 / 3; +} + +/** + * @brief For compilers that don't natively support the `_Float16` type, + * upcasts contents into a more conventional `float`. + * + * @warning This function won't handle boundary conditions well. + * + * https://stackoverflow.com/a/60047308 + * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233 + * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834 + */ +SIMSIMD_PUBLIC simsimd_f32_t simsimd_uncompress_f16(unsigned short x) { + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + unsigned int exponent = (x & 0x7C00) >> 10; + unsigned int mantissa = (x & 0x03FF) << 13; + union float_or_unsigned_int_t mantissa_union; + mantissa_union.f = (float)mantissa; + unsigned int v = (mantissa_union.i) >> 23; + union float_or_unsigned_int_t result_union; + result_union.i = (x & 0x8000) << 16 | (exponent != 0) * ((exponent + 112) << 23 | mantissa) | + ((exponent == 0) & (mantissa != 0)) * ((v - 37) << 23 | ((mantissa << (150 - v)) & 0x007FE000)); + return result_union.f; +} + +/** + * @brief Compresses a `float` to an `f16` representation (IEEE-754 16-bit floating-point format). + * + * @warning This function won't handle boundary conditions well. + * + * https://stackoverflow.com/a/60047308 + * https://gist.github.com/milhidaka/95863906fe828198f47991c813dbe233 + * https://github.com/OpenCyphal/libcanard/blob/636795f4bc395f56af8d2c61d3757b5e762bb9e5/canard.c#L811-L834 + */ +SIMSIMD_PUBLIC unsigned short simsimd_compress_f16(simsimd_f32_t x) { + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + + unsigned int b = *(unsigned int*)&x + 0x00001000; + unsigned int e = (b & 0x7F800000) >> 23; + unsigned int m = b & 0x007FFFFF; + unsigned short result = (b & 0x80000000) >> 16 | (e > 112) * (((e - 112) << 10) & 0x7C00 | m >> 13) | + ((e < 113) & (e > 101)) * (((0x007FF000 + m) >> ((125 - e) + 1)) >> 1) | (e > 143) * 0x7FFF; + return result; +} + +/** + * @brief For compilers that don't natively support the `__bf16` type, + * upcasts contents into a more conventional `float`. + * + * https://stackoverflow.com/questions/55253233/convert-fp32-to-bfloat16-in-c/55254307#55254307 + * https://cloud.google.com/blog/products/ai-machine-learning/bfloat16-the-secret-to-high-performance-on-cloud-tpus + */ +SIMSIMD_PUBLIC simsimd_f32_t simsimd_uncompress_bf16(unsigned short x) { + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + union float_or_unsigned_int_t result_union; + result_union.i = x << 16; // Zero extends the mantissa + return result_union.f; +} + +/** + * @brief Compresses a `float` to a `bf16` representation. 
+ */ +SIMSIMD_PUBLIC unsigned short simsimd_compress_bf16(simsimd_f32_t x) { + union float_or_unsigned_int_t { + float f; + unsigned int i; + }; + union float_or_unsigned_int_t value; + value.f = x; + value.i >>= 16; + value.i &= 0xFFFF; + return (unsigned short)value.i; +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/src/include/usearch/duckdb_usearch.hpp b/src/include/usearch/duckdb_usearch.hpp index 5102394..4914fdb 100644 --- a/src/include/usearch/duckdb_usearch.hpp +++ b/src/include/usearch/duckdb_usearch.hpp @@ -3,10 +3,14 @@ #include #include -#define USEARCH_USE_SIMSIMD 0 +#define USEARCH_USE_SIMSIMD DUCKDB_USEARCH_USE_SIMSIMD #define USEARCH_USE_FP16LIB 1 #define USEARCH_USE_OPENMP 0 #include "usearch/index.hpp" #include "usearch/index_dense.hpp" -#include "usearch/index_plugins.hpp" \ No newline at end of file +#include "usearch/index_plugins.hpp" + +#undef USEARCH_USE_SIMSIMD +#undef USEARCH_USE_FP16LIB +#undef USEARCH_USE_OPENMP \ No newline at end of file diff --git a/src/include/usearch/index.hpp b/src/include/usearch/index.hpp index 6e5150d..ef79586 100644 --- a/src/include/usearch/index.hpp +++ b/src/include/usearch/index.hpp @@ -1,17 +1,15 @@ /** -* @file index.hpp -* @author Ash Vardanian -* @brief Single-header Vector Search. -* @date 2023-04-26 -* -* @copyright Copyright (c) 2023 -*/ + * @file index.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine. + * @date April 26, 2023 + */ #ifndef UNUM_USEARCH_HPP #define UNUM_USEARCH_HPP #define USEARCH_VERSION_MAJOR 2 -#define USEARCH_VERSION_MINOR 9 -#define USEARCH_VERSION_PATCH 2 +#define USEARCH_VERSION_MINOR 12 +#define USEARCH_VERSION_PATCH 0 // Inferring C++ version // https://stackoverflow.com/a/61552074 @@ -114,58 +112,72 @@ #define usearch_noexcept_m noexcept #else #define usearch_assert_m(must_be_true, message) \ - if (!(must_be_true)) { \ - throw std::runtime_error(message); \ - } + if (!(must_be_true)) { \ + __usearch_raise_runtime_error(message); \ + } #define usearch_noexcept_m #endif +extern "C" { +/// @brief Helper function to simplify debugging - trace just one symbol - `__usearch_raise_runtime_error`. +/// Assuming the `extern C` block, the name won't be mangled. +inline static void __usearch_raise_runtime_error(char const* message) { + // On Windows we compile with `/EHc` flag, which specifies that functions + // with C linkage do not throw C++ exceptions. +#if !defined(__cpp_exceptions) || defined(USEARCH_DEFINED_WINDOWS) + std::terminate(); +#else + throw std::runtime_error(message); +#endif +} +} + namespace unum { namespace usearch { using byte_t = char; template std::size_t divide_round_up(std::size_t num) noexcept { - return (num + multiple_ak - 1) / multiple_ak; + return (num + multiple_ak - 1) / multiple_ak; } inline std::size_t divide_round_up(std::size_t num, std::size_t denominator) noexcept { - return (num + denominator - 1) / denominator; + return (num + denominator - 1) / denominator; } inline std::size_t ceil2(std::size_t v) noexcept { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; #ifdef USEARCH_64BIT_ENV - v |= v >> 32; + v |= v >> 32; #endif - v++; - return v; + v++; + return v; } /// @brief Simply dereferencing misaligned pointers can be dangerous. 
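/// A direct cast like `*(std::uint32_t*)(buffer + 1)` is Undefined Behavior when the
/// address is not suitably aligned; `std::memcpy` expresses the same byte-wise intent and
/// typically compiles down to a single unaligned move on x86 and modern Arm.
/// A hypothetical usage sketch:
///
///     byte_t buffer[8] = {};
///     misaligned_store<std::uint32_t>(buffer + 1, 42u); // odd address, still well-defined
///     std::uint32_t value = misaligned_load<std::uint32_t>(buffer + 1);
///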
template void misaligned_store(void* ptr, at v) noexcept { - static_assert(!std::is_reference::value, "Can't store a reference"); - std::memcpy(ptr, &v, sizeof(at)); + static_assert(!std::is_reference::value, "Can't store a reference"); + std::memcpy(ptr, &v, sizeof(at)); } /// @brief Simply dereferencing misaligned pointers can be dangerous. template at misaligned_load(void* ptr) noexcept { - static_assert(!std::is_reference::value, "Can't load a reference"); - at v; - std::memcpy(&v, ptr, sizeof(at)); - return v; + static_assert(!std::is_reference::value, "Can't load a reference"); + at v; + std::memcpy(&v, ptr, sizeof(at)); + return v; } /// @brief The `std::exchange` alternative for C++11. template at exchange(at& obj, other_at&& new_value) { - at old_value = std::move(obj); - obj = std::forward(new_value); - return old_value; + at old_value = std::move(obj); + obj = std::forward(new_value); + return old_value; } /// @brief The `std::destroy_at` alternative for C++11. @@ -173,7 +185,7 @@ template typename std::enable_if::value>::type destroy_at(at*) {} template typename std::enable_if::value>::type destroy_at(at* obj) { - obj->~sfinae_at(); + obj->~sfinae_at(); } /// @brief The `std::construct_at` alternative for C++11. @@ -181,576 +193,576 @@ template typename std::enable_if::value>::type construct_at(at*) {} template typename std::enable_if::value>::type construct_at(at* obj) { - new (obj) at(); + new (obj) at(); } /** -* @brief A reference to a misaligned memory location with a specific type. -* It is needed to avoid Undefined Behavior when dereferencing addresses -* indivisible by `sizeof(at)`. -*/ + * @brief A reference to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. + */ template class misaligned_ref_gt { - using element_t = at; - using mutable_t = typename std::remove_const::type; - byte_t* ptr_; - -public: - misaligned_ref_gt(byte_t* ptr) noexcept : ptr_(ptr) {} - operator mutable_t() const noexcept { return misaligned_load(ptr_); } - misaligned_ref_gt& operator=(mutable_t const& v) noexcept { - misaligned_store(ptr_, v); - return *this; - } - - void reset(byte_t* ptr) noexcept { ptr_ = ptr; } - byte_t* ptr() const noexcept { return ptr_; } + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + misaligned_ref_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + operator mutable_t() const noexcept { return misaligned_load(ptr_); } + misaligned_ref_gt& operator=(mutable_t const& v) noexcept { + misaligned_store(ptr_, v); + return *this; + } + + void reset(byte_t* ptr) noexcept { ptr_ = ptr; } + byte_t* ptr() const noexcept { return ptr_; } }; /** -* @brief A pointer to a misaligned memory location with a specific type. -* It is needed to avoid Undefined Behavior when dereferencing addresses -* indivisible by `sizeof(at)`. -*/ + * @brief A pointer to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. 
+ */ template class misaligned_ptr_gt { - using element_t = at; - using mutable_t = typename std::remove_const::type; - byte_t* ptr_; - -public: - using iterator_category = std::random_access_iterator_tag; - using value_type = element_t; - using difference_type = std::ptrdiff_t; - using pointer = misaligned_ptr_gt; - using reference = misaligned_ref_gt; - - reference operator*() const noexcept { return {ptr_}; } - reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } - value_type operator[](std::size_t i) const noexcept { - return misaligned_load(ptr_ + i * sizeof(element_t)); - } - - misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} - misaligned_ptr_gt operator++(int) noexcept { return misaligned_ptr_gt(ptr_ + sizeof(element_t)); } - misaligned_ptr_gt operator--(int) noexcept { return misaligned_ptr_gt(ptr_ - sizeof(element_t)); } - misaligned_ptr_gt operator+(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); } - misaligned_ptr_gt operator-(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); } - - // clang-format off - misaligned_ptr_gt& operator++() noexcept { ptr_ += sizeof(element_t); return *this; } - misaligned_ptr_gt& operator--() noexcept { ptr_ -= sizeof(element_t); return *this; } - misaligned_ptr_gt& operator+=(difference_type d) noexcept { ptr_ += d * sizeof(element_t); return *this; } - misaligned_ptr_gt& operator-=(difference_type d) noexcept { ptr_ -= d * sizeof(element_t); return *this; } - // clang-format on - - bool operator==(misaligned_ptr_gt const& other) noexcept { return ptr_ == other.ptr_; } - bool operator!=(misaligned_ptr_gt const& other) noexcept { return ptr_ != other.ptr_; } + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + reference operator*() const noexcept { return {ptr_}; } + reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } + value_type operator[](std::size_t i) const noexcept { + return misaligned_load(ptr_ + i * sizeof(element_t)); + } + + misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + misaligned_ptr_gt operator++(int) noexcept { return misaligned_ptr_gt(ptr_ + sizeof(element_t)); } + misaligned_ptr_gt operator--(int) noexcept { return misaligned_ptr_gt(ptr_ - sizeof(element_t)); } + misaligned_ptr_gt operator+(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); } + misaligned_ptr_gt operator-(difference_type d) noexcept { return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); } + + // clang-format off + misaligned_ptr_gt& operator++() noexcept { ptr_ += sizeof(element_t); return *this; } + misaligned_ptr_gt& operator--() noexcept { ptr_ -= sizeof(element_t); return *this; } + misaligned_ptr_gt& operator+=(difference_type d) noexcept { ptr_ += d * sizeof(element_t); return *this; } + misaligned_ptr_gt& operator-=(difference_type d) noexcept { ptr_ -= d * sizeof(element_t); return *this; } + // clang-format on + + bool operator==(misaligned_ptr_gt const& other) noexcept { return ptr_ == other.ptr_; } + bool operator!=(misaligned_ptr_gt const& other) noexcept { return ptr_ != other.ptr_; } }; /** -* @brief Non-owning memory range view, similar to `std::span`, but for C++11. 
-*/ + * @brief Non-owning memory range view, similar to `std::span`, but for C++11. + */ template class span_gt { - scalar_at* data_; - std::size_t size_; - -public: - span_gt() noexcept : data_(nullptr), size_(0u) {} - span_gt(scalar_at* begin, scalar_at* end) noexcept : data_(begin), size_(end - begin) {} - span_gt(scalar_at* begin, std::size_t size) noexcept : data_(begin), size_(size) {} - scalar_at* data() const noexcept { return data_; } - std::size_t size() const noexcept { return size_; } - scalar_at* begin() const noexcept { return data_; } - scalar_at* end() const noexcept { return data_ + size_; } - operator scalar_at*() const noexcept { return data(); } + scalar_at* data_; + std::size_t size_; + + public: + span_gt() noexcept : data_(nullptr), size_(0u) {} + span_gt(scalar_at* begin, scalar_at* end) noexcept : data_(begin), size_(end - begin) {} + span_gt(scalar_at* begin, std::size_t size) noexcept : data_(begin), size_(size) {} + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } }; /** -* @brief Similar to `std::vector`, but doesn't support dynamic resizing. -* On the bright side, this can't throw exceptions. -*/ + * @brief Similar to `std::vector`, but doesn't support dynamic resizing. + * On the bright side, this can't throw exceptions. + */ template > class buffer_gt { - scalar_at* data_; - std::size_t size_; - -public: - buffer_gt() noexcept : data_(nullptr), size_(0u) {} - buffer_gt(std::size_t size) noexcept : data_(allocator_at{}.allocate(size)), size_(data_ ? size : 0u) { - if (!std::is_trivially_default_constructible::value) - for (std::size_t i = 0; i != size_; ++i) - construct_at(data_ + i); - } - ~buffer_gt() noexcept { - if (!std::is_trivially_destructible::value) - for (std::size_t i = 0; i != size_; ++i) - destroy_at(data_ + i); - allocator_at{}.deallocate(data_, size_); - data_ = nullptr; - size_ = 0; - } - scalar_at* data() const noexcept { return data_; } - std::size_t size() const noexcept { return size_; } - scalar_at* begin() const noexcept { return data_; } - scalar_at* end() const noexcept { return data_ + size_; } - operator scalar_at*() const noexcept { return data(); } - scalar_at& operator[](std::size_t i) noexcept { return data_[i]; } - scalar_at const& operator[](std::size_t i) const noexcept { return data_[i]; } - explicit operator bool() const noexcept { return data_; } - scalar_at* release() noexcept { - size_ = 0; - return exchange(data_, nullptr); - } - - buffer_gt(buffer_gt const&) = delete; - buffer_gt& operator=(buffer_gt const&) = delete; - - buffer_gt(buffer_gt&& other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) {} - buffer_gt& operator=(buffer_gt&& other) noexcept { - std::swap(data_, other.data_); - std::swap(size_, other.size_); - return *this; - } + scalar_at* data_; + std::size_t size_; + + public: + buffer_gt() noexcept : data_(nullptr), size_(0u) {} + buffer_gt(std::size_t size) noexcept : data_(allocator_at{}.allocate(size)), size_(data_ ? 
size : 0u) { + if (!std::is_trivially_default_constructible::value) + for (std::size_t i = 0; i != size_; ++i) + construct_at(data_ + i); + } + ~buffer_gt() noexcept { + if (!std::is_trivially_destructible::value) + for (std::size_t i = 0; i != size_; ++i) + destroy_at(data_ + i); + allocator_at{}.deallocate(data_, size_); + data_ = nullptr; + size_ = 0; + } + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } + scalar_at& operator[](std::size_t i) noexcept { return data_[i]; } + scalar_at const& operator[](std::size_t i) const noexcept { return data_[i]; } + explicit operator bool() const noexcept { return data_; } + scalar_at* release() noexcept { + size_ = 0; + return exchange(data_, nullptr); + } + + buffer_gt(buffer_gt const&) = delete; + buffer_gt& operator=(buffer_gt const&) = delete; + + buffer_gt(buffer_gt&& other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) {} + buffer_gt& operator=(buffer_gt&& other) noexcept { + std::swap(data_, other.data_); + std::swap(size_, other.size_); + return *this; + } }; /** -* @brief A lightweight error class for handling error messages, -* which are expected to be allocated in static memory. -*/ + * @brief A lightweight error class for handling error messages, + * which are expected to be allocated in static memory. + */ class error_t { - char const* message_{}; - -public: - error_t(char const* message = nullptr) noexcept : message_(message) {} - error_t& operator=(char const* message) noexcept { - message_ = message; - return *this; - } - - error_t(error_t const&) = delete; - error_t& operator=(error_t const&) = delete; - error_t(error_t&& other) noexcept : message_(exchange(other.message_, nullptr)) {} - error_t& operator=(error_t&& other) noexcept { - std::swap(message_, other.message_); - return *this; - } - explicit operator bool() const noexcept { return message_ != nullptr; } - char const* what() const noexcept { return message_; } - char const* release() noexcept { return exchange(message_, nullptr); } + char const* message_{}; + + public: + error_t(char const* message = nullptr) noexcept : message_(message) {} + error_t& operator=(char const* message) noexcept { + message_ = message; + return *this; + } + + error_t(error_t const&) = delete; + error_t& operator=(error_t const&) = delete; + error_t(error_t&& other) noexcept : message_(exchange(other.message_, nullptr)) {} + error_t& operator=(error_t&& other) noexcept { + std::swap(message_, other.message_); + return *this; + } + explicit operator bool() const noexcept { return message_ != nullptr; } + char const* what() const noexcept { return message_; } + char const* release() noexcept { return exchange(message_, nullptr); } #if defined(__cpp_exceptions) || defined(__EXCEPTIONS) - ~error_t() noexcept(false) { + ~error_t() noexcept(false) { #if defined(USEARCH_DEFINED_CPP17) - if (message_ && std::uncaught_exceptions() == 0) + if (message_ && std::uncaught_exceptions() == 0) #else - if (message_ && std::uncaught_exception() == 0) + if (message_ && std::uncaught_exception() == 0) #endif - raise(); - } - void raise() noexcept(false) { - if (message_) - throw std::runtime_error(exchange(message_, nullptr)); - } + raise(); + } + void raise() noexcept(false) { + if (message_) + throw std::runtime_error(exchange(message_, nullptr)); + } 
#else - ~error_t() noexcept { raise(); } - void raise() noexcept { - if (message_) - std::terminate(); - } + ~error_t() noexcept { raise(); } + void raise() noexcept { + if (message_) + std::terminate(); + } #endif }; /** -* @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, -* or an error. It's used to avoid raising exception, and gracefully propagate -* the error. -* -* @tparam result_at The type of the expected result. -*/ + * @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, + * or an error. It's used to avoid raising exception, and gracefully propagate + * the error. + * + * @tparam result_at The type of the expected result. + */ template struct expected_gt { - result_at result; - error_t error; - - operator result_at&() & { - error.raise(); - return result; - } - operator result_at&&() && { - error.raise(); - return std::move(result); - } - result_at const& operator*() const noexcept { return result; } - explicit operator bool() const noexcept { return !error; } - expected_gt failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } + result_at result; + error_t error; + + operator result_at&() & { + error.raise(); + return result; + } + operator result_at&&() && { + error.raise(); + return std::move(result); + } + result_at const& operator*() const noexcept { return result; } + explicit operator bool() const noexcept { return !error; } + expected_gt failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; /** -* @brief Light-weight bitset implementation to sync nodes updates during graph mutations. -* Extends basic functionality with @b atomic operations. -*/ + * @brief Light-weight bitset implementation to sync nodes updates during graph mutations. + * Extends basic functionality with @b atomic operations. + */ template > class bitset_gt { - using allocator_t = allocator_at; - using byte_t = typename allocator_t::value_type; - static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); - - using compressed_slot_t = unsigned long; - - static constexpr std::size_t bits_per_slot() { return sizeof(compressed_slot_t) * CHAR_BIT; } - static constexpr compressed_slot_t bits_mask() { return sizeof(compressed_slot_t) * CHAR_BIT - 1; } - static constexpr std::size_t slots(std::size_t bits) { return divide_round_up(bits); } - - compressed_slot_t* slots_{}; - /// @brief Number of slots. - std::size_t count_{}; - -public: - bitset_gt() noexcept {} - ~bitset_gt() noexcept { reset(); } - - explicit operator bool() const noexcept { return slots_; } - void clear() noexcept { - if (slots_) - std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); - } - - void reset() noexcept { - if (slots_) - allocator_t{}.deallocate((byte_t*)slots_, count_ * sizeof(compressed_slot_t)); - slots_ = nullptr; - count_ = 0; - } - - bitset_gt(std::size_t capacity) noexcept - : slots_((compressed_slot_t*)allocator_t{}.allocate(slots(capacity) * sizeof(compressed_slot_t))), - count_(slots_ ? 
slots(capacity) : 0u) { - clear(); - } - - bitset_gt(bitset_gt&& other) noexcept { - slots_ = exchange(other.slots_, nullptr); - count_ = exchange(other.count_, 0); - } - - bitset_gt& operator=(bitset_gt&& other) noexcept { - std::swap(slots_, other.slots_); - std::swap(count_, other.count_); - return *this; - } - - bitset_gt(bitset_gt const&) = delete; - bitset_gt& operator=(bitset_gt const&) = delete; - - inline bool test(std::size_t i) const noexcept { return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); } - inline bool set(std::size_t i) noexcept { - compressed_slot_t& slot = slots_[i / bits_per_slot()]; - compressed_slot_t mask{1ul << (i & bits_mask())}; - bool value = slot & mask; - slot |= mask; - return value; - } + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + using compressed_slot_t = unsigned long; + + static constexpr std::size_t bits_per_slot() { return sizeof(compressed_slot_t) * CHAR_BIT; } + static constexpr compressed_slot_t bits_mask() { return sizeof(compressed_slot_t) * CHAR_BIT - 1; } + static constexpr std::size_t slots(std::size_t bits) { return divide_round_up(bits); } + + compressed_slot_t* slots_{}; + /// @brief Number of slots. + std::size_t count_{}; + + public: + bitset_gt() noexcept {} + ~bitset_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + void clear() noexcept { + if (slots_) + std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, count_ * sizeof(compressed_slot_t)); + slots_ = nullptr; + count_ = 0; + } + + bitset_gt(std::size_t capacity) noexcept + : slots_((compressed_slot_t*)allocator_t{}.allocate(slots(capacity) * sizeof(compressed_slot_t))), + count_(slots_ ? 
slots(capacity) : 0u) { + clear(); + } + + bitset_gt(bitset_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + count_ = exchange(other.count_, 0); + } + + bitset_gt& operator=(bitset_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(count_, other.count_); + return *this; + } + + bitset_gt(bitset_gt const&) = delete; + bitset_gt& operator=(bitset_gt const&) = delete; + + inline bool test(std::size_t i) const noexcept { return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); } + inline bool set(std::size_t i) noexcept { + compressed_slot_t& slot = slots_[i / bits_per_slot()]; + compressed_slot_t mask{1ul << (i & bits_mask())}; + bool value = slot & mask; + slot |= mask; + return value; + } #if defined(USEARCH_DEFINED_WINDOWS) - inline bool atomic_set(std::size_t i) noexcept { - compressed_slot_t mask{1ul << (i & bits_mask())}; - return InterlockedOr((long volatile*)&slots_[i / bits_per_slot()], mask) & mask; - } + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return InterlockedOr((long volatile*)&slots_[i / bits_per_slot()], mask) & mask; + } - inline void atomic_reset(std::size_t i) noexcept { - compressed_slot_t mask{1ul << (i & bits_mask())}; - InterlockedAnd((long volatile*)&slots_[i / bits_per_slot()], ~mask); - } + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + InterlockedAnd((long volatile*)&slots_[i / bits_per_slot()], ~mask); + } #else - inline bool atomic_set(std::size_t i) noexcept { - compressed_slot_t mask{1ul << (i & bits_mask())}; - return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; - } + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; + } - inline void atomic_reset(std::size_t i) noexcept { - compressed_slot_t mask{1ul << (i & bits_mask())}; - __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); - } + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); + } #endif - class lock_t { - bitset_gt& bitset_; - std::size_t bit_offset_; + class lock_t { + bitset_gt& bitset_; + std::size_t bit_offset_; - public: - inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); } - inline lock_t(bitset_gt& bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { - while (bitset_.atomic_set(bit_offset_)) - ; - } - }; + public: + inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); } + inline lock_t(bitset_gt& bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { + while (bitset_.atomic_set(bit_offset_)) + ; + } + }; - inline lock_t lock(std::size_t i) noexcept { return {*this, i}; } + inline lock_t lock(std::size_t i) noexcept { return {*this, i}; } }; using bitset_t = bitset_gt<>; /** -* @brief Similar to `std::priority_queue`, but allows raw access to underlying -* memory, in case you want to shuffle it or sort. Good for collections -* from 100s to 10'000s elements. -*/ + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory, in case you want to shuffle it or sort. Good for collections + * from 100s to 10'000s elements. + */ template , // is needed before C++14. 
- typename allocator_at = std::allocator> // + typename comparator_at = std::less, // is needed before C++14. + typename allocator_at = std::allocator> // class max_heap_gt { -public: - using element_t = element_at; - using comparator_t = comparator_at; - using allocator_t = allocator_at; - - using value_type = element_t; - - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); - -private: - element_t* elements_; - std::size_t size_; - std::size_t capacity_; - -public: - max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} - - max_heap_gt(max_heap_gt&& other) noexcept - : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), - capacity_(exchange(other.capacity_, 0)) {} - - max_heap_gt& operator=(max_heap_gt&& other) noexcept { - std::swap(elements_, other.elements_); - std::swap(size_, other.size_); - std::swap(capacity_, other.capacity_); - return *this; - } - - max_heap_gt(max_heap_gt const&) = delete; - max_heap_gt& operator=(max_heap_gt const&) = delete; - - ~max_heap_gt() noexcept { reset(); } - - void reset() noexcept { - if (elements_) - allocator_t{}.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - size_ = 0; - } - - inline bool empty() const noexcept { return !size_; } - inline std::size_t size() const noexcept { return size_; } - inline std::size_t capacity() const noexcept { return capacity_; } - - /// @brief Selects the largest element in the heap. - /// @return Reference to the stored element. - inline element_t const& top() const noexcept { return elements_[0]; } - inline void clear() noexcept { size_ = 0; } - - bool reserve(std::size_t new_capacity) noexcept { - if (new_capacity < capacity_) - return true; - - new_capacity = ceil2(new_capacity); - new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); - auto allocator = allocator_t{}; - auto new_elements = allocator.allocate(new_capacity); - if (!new_elements) - return false; - - if (elements_) { - std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); - allocator.deallocate(elements_, capacity_); - } - elements_ = new_elements; - capacity_ = new_capacity; - return new_elements; - } - - bool insert(element_t&& element) noexcept { - if (!reserve(size_ + 1)) - return false; - - insert_reserved(std::move(element)); - return true; - } - - inline void insert_reserved(element_t&& element) noexcept { - new (&elements_[size_]) element_t(element); - size_++; - shift_up(size_ - 1); - } - - inline element_t pop() noexcept { - element_t result = top(); - std::swap(elements_[0], elements_[size_ - 1]); - size_--; - elements_[size_].~element_t(); - shift_down(0); - return result; - } - - /** @brief Invalidates the "max-heap" property, transforming into ascending range. 
*/ - inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } - inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } - - inline element_t* data() noexcept { return elements_; } - inline element_t const* data() const noexcept { return elements_; } - -private: - inline std::size_t parent_idx(std::size_t i) const noexcept { return (i - 1u) / 2u; } - inline std::size_t left_child_idx(std::size_t i) const noexcept { return (i * 2u) + 1u; } - inline std::size_t right_child_idx(std::size_t i) const noexcept { return (i * 2u) + 2u; } - static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } - - void shift_up(std::size_t i) noexcept { - for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) - std::swap(elements_[parent_idx(i)], elements_[i]); - } - - void shift_down(std::size_t i) noexcept { - std::size_t max_idx = i; - - std::size_t left = left_child_idx(i); - if (left < size_ && less(elements_[max_idx], elements_[left])) - max_idx = left; - - std::size_t right = right_child_idx(i); - if (right < size_ && less(elements_[max_idx], elements_[right])) - max_idx = right; - - if (i != max_idx) { - std::swap(elements_[i], elements_[max_idx]); - shift_down(max_idx); - } - } + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + using value_type = element_t; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + max_heap_gt(max_heap_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + max_heap_gt& operator=(max_heap_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + max_heap_gt(max_heap_gt const&) = delete; + max_heap_gt& operator=(max_heap_gt const&) = delete; + + ~max_heap_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + + /// @brief Selects the largest element in the heap. + /// @return Reference to the stored element. 
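    /// A hypothetical usage sketch (float keys with the default `std::less` comparator,
    /// so `pop()` always returns the current maximum):
    ///
    ///     max_heap_gt<float> heap;
    ///     heap.reserve(3);
    ///     heap.insert(1.f), heap.insert(3.f), heap.insert(2.f);
    ///     float largest = heap.pop(); // 3.f
    ///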
+ inline element_t const& top() const noexcept { return elements_[0]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (elements_) { + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + allocator.deallocate(elements_, capacity_); + } + elements_ = new_elements; + capacity_ = new_capacity; + return new_elements; + } + + bool insert(element_t&& element) noexcept { + if (!reserve(size_ + 1)) + return false; + + insert_reserved(std::move(element)); + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + new (&elements_[size_]) element_t(element); + size_++; + shift_up(size_ - 1); + } + + inline element_t pop() noexcept { + element_t result = top(); + std::swap(elements_[0], elements_[size_ - 1]); + size_--; + elements_[size_].~element_t(); + shift_down(0); + return result; + } + + /** @brief Invalidates the "max-heap" property, transforming into ascending range. */ + inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + + private: + inline std::size_t parent_idx(std::size_t i) const noexcept { return (i - 1u) / 2u; } + inline std::size_t left_child_idx(std::size_t i) const noexcept { return (i * 2u) + 1u; } + inline std::size_t right_child_idx(std::size_t i) const noexcept { return (i * 2u) + 2u; } + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + + void shift_up(std::size_t i) noexcept { + for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) + std::swap(elements_[parent_idx(i)], elements_[i]); + } + + void shift_down(std::size_t i) noexcept { + std::size_t max_idx = i; + + std::size_t left = left_child_idx(i); + if (left < size_ && less(elements_[max_idx], elements_[left])) + max_idx = left; + + std::size_t right = right_child_idx(i); + if (right < size_ && less(elements_[max_idx], elements_[right])) + max_idx = right; + + if (i != max_idx) { + std::swap(elements_[i], elements_[max_idx]); + shift_down(max_idx); + } + } }; /** -* @brief Similar to `std::priority_queue`, but allows raw access to underlying -* memory and always keeps the data sorted. Ideal for small collections -* under 128 elements. -*/ + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory and always keeps the data sorted. Ideal for small collections + * under 128 elements. + */ template , // is needed before C++14. - typename allocator_at = std::allocator> // + typename comparator_at = std::less, // is needed before C++14. 
+ typename allocator_at = std::allocator> // class sorted_buffer_gt { -public: - using element_t = element_at; - using comparator_t = comparator_at; - using allocator_t = allocator_at; - - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); - - using value_type = element_t; - -private: - element_t* elements_; - std::size_t size_; - std::size_t capacity_; - -public: - sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} - - sorted_buffer_gt(sorted_buffer_gt&& other) noexcept - : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), - capacity_(exchange(other.capacity_, 0)) {} - - sorted_buffer_gt& operator=(sorted_buffer_gt&& other) noexcept { - std::swap(elements_, other.elements_); - std::swap(size_, other.size_); - std::swap(capacity_, other.capacity_); - return *this; - } - - sorted_buffer_gt(sorted_buffer_gt const&) = delete; - sorted_buffer_gt& operator=(sorted_buffer_gt const&) = delete; - - ~sorted_buffer_gt() noexcept { reset(); } - - void reset() noexcept { - if (elements_) - allocator_t{}.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - size_ = 0; - } - - inline bool empty() const noexcept { return !size_; } - inline std::size_t size() const noexcept { return size_; } - inline std::size_t capacity() const noexcept { return capacity_; } - inline element_t const& top() const noexcept { return elements_[size_ - 1]; } - inline void clear() noexcept { size_ = 0; } - - bool reserve(std::size_t new_capacity) noexcept { - if (new_capacity < capacity_) - return true; - - new_capacity = ceil2(new_capacity); - new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); - auto allocator = allocator_t{}; - auto new_elements = allocator.allocate(new_capacity); - if (!new_elements) - return false; - - if (size_) - std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); - if (elements_) - allocator.deallocate(elements_, capacity_); - - elements_ = new_elements; - capacity_ = new_capacity; - return true; - } - - inline void insert_reserved(element_t&& element) noexcept { - std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; - std::size_t to_move = size_ - slot; - element_t* source = elements_ + size_ - 1; - for (; to_move; --to_move, --source) - source[1] = source[0]; - elements_[slot] = element; - size_++; - } - - /** - * @return `true` if the entry was added, `false` if it wasn't relevant enough. - */ - inline bool insert(element_t&& element, std::size_t limit) noexcept { - std::size_t slot = size_ ? 
std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; - if (slot == limit) - return false; - std::size_t to_move = size_ - slot - (size_ == limit); - element_t* source = elements_ + size_ - 1 - (size_ == limit); - for (; to_move; --to_move, --source) - source[1] = source[0]; - elements_[slot] = element; - size_ += size_ != limit; - return true; - } - - inline element_t pop() noexcept { - size_--; - element_t result = elements_[size_]; - elements_[size_].~element_t(); - return result; - } - - void sort_ascending() noexcept {} - inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } - - inline element_t* data() noexcept { return elements_; } - inline element_t const* data() const noexcept { return elements_; } - -private: - static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + sorted_buffer_gt(sorted_buffer_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + sorted_buffer_gt& operator=(sorted_buffer_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + sorted_buffer_gt(sorted_buffer_gt const&) = delete; + sorted_buffer_gt& operator=(sorted_buffer_gt const&) = delete; + + ~sorted_buffer_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + inline element_t const& top() const noexcept { return elements_[size_ - 1]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (size_) + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + if (elements_) + allocator.deallocate(elements_, capacity_); + + elements_ = new_elements; + capacity_ = new_capacity; + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + std::size_t to_move = size_ - slot; + element_t* source = elements_ + size_ - 1; + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_++; + } + + /** + * @return `true` if the entry was added, `false` if it wasn't relevant enough. + */ + inline bool insert(element_t&& element, std::size_t limit) noexcept { + std::size_t slot = size_ ? 
std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + if (slot == limit) + return false; + std::size_t to_move = size_ - slot - (size_ == limit); + element_t* source = elements_ + size_ - 1 - (size_ == limit); + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_ += size_ != limit; + return true; + } + + inline element_t pop() noexcept { + size_--; + element_t result = elements_[size_]; + elements_[size_].~element_t(); + return result; + } + + void sort_ascending() noexcept {} + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + + private: + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } }; #if defined(USEARCH_DEFINED_WINDOWS) @@ -758,53 +770,53 @@ class sorted_buffer_gt { #endif /** -* @brief Five-byte integer type to address node clouds with over 4B entries. -* -* @note Avoid usage in 32bit environment -*/ + * @brief Five-byte integer type to address node clouds with over 4B entries. + * + * @note Avoid usage in 32bit environment + */ class usearch_pack_m uint40_t { - unsigned char octets[5]; + unsigned char octets[5]; - inline uint40_t& broadcast(unsigned char c) { - std::memset(octets, c, 5); - return *this; - } + inline uint40_t& broadcast(unsigned char c) { + std::memset(octets, c, 5); + return *this; + } -public: - inline uint40_t() noexcept { broadcast(0); } - inline uint40_t(std::uint32_t n) noexcept { std::memcpy(&octets[1], &n, 4); } + public: + inline uint40_t() noexcept { broadcast(0); } + inline uint40_t(std::uint32_t n) noexcept { std::memcpy(&octets[1], &n, 4); } #ifdef USEARCH_64BIT_ENV - inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); } + inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); } #endif - uint40_t(uint40_t&&) = default; - uint40_t(uint40_t const&) = default; - uint40_t& operator=(uint40_t&&) = default; - uint40_t& operator=(uint40_t const&) = default; + uint40_t(uint40_t&&) = default; + uint40_t(uint40_t const&) = default; + uint40_t& operator=(uint40_t&&) = default; + uint40_t& operator=(uint40_t const&) = default; #if defined(USEARCH_DEFINED_CLANG) && defined(USEARCH_DEFINED_APPLE) - inline uint40_t(std::size_t n) noexcept { + inline uint40_t(std::size_t n) noexcept { #ifdef USEARCH_64BIT_ENV - std::memcpy(octets, &n, 5); + std::memcpy(octets, &n, 5); #else - std::memcpy(octets, &n, 4); + std::memcpy(octets, &n, 4); #endif - } + } #endif - inline operator std::size_t() const noexcept { - std::size_t result = 0; + inline operator std::size_t() const noexcept { + std::size_t result = 0; #ifdef USEARCH_64BIT_ENV - std::memcpy(&result, octets, 5); + std::memcpy(&result, octets, 5); #else - std::memcpy(&result, octets + 1, 4); + std::memcpy(&result, octets + 1, 4); #endif - return result; - } + return result; + } - inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); } - inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); } + inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); } + inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); } }; #if defined(USEARCH_DEFINED_WINDOWS) @@ -820,263 +832,263 @@ template ::va // clang-format on template struct hash_gt { - std::size_t operator()(element_at const& element) const noexcept { return std::hash{}(element); } + 
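// Editorial note (hedged): the `uint40_t` type above packs slot identifiers into five bytes,
// saving three bytes per stored slot compared to `uint64_t`, which matters for graphs with
// billions of nodes. On a little-endian 64-bit build it round-trips losslessly:
//
//     uint40_t slot{std::uint64_t{5000000000ull}}; // larger than any 32-bit value
//     std::size_t restored = slot;                 // implicit conversion restores the same value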
std::size_t operator()(element_at const& element) const noexcept { return std::hash{}(element); } }; template <> struct hash_gt { - std::size_t operator()(uint40_t const& element) const noexcept { return std::hash{}(element); } + std::size_t operator()(uint40_t const& element) const noexcept { return std::hash{}(element); } }; /** -* @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. -* -* It doesn't support deletion of separate objects, but supports `clear`-ing all at once. -* It expects `reserve` to be called ahead of all insertions, so no resizes are needed. -* It also assumes `0xFF...FF` slots to be unused, to simplify the design. -* It uses linear probing, the number of slots is always a power of two, and it uses linear-probing -* in case of bucket collisions. -*/ + * @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. + * + * It doesn't support deletion of separate objects, but supports `clear`-ing all at once. + * It expects `reserve` to be called ahead of all insertions, so no resizes are needed. + * It also assumes `0xFF...FF` slots to be unused, to simplify the design. + * It uses linear probing, the number of slots is always a power of two, and it uses linear-probing + * in case of bucket collisions. + */ template , typename allocator_at = std::allocator> class growing_hash_set_gt { - using element_t = element_at; - using hasher_t = hasher_at; - - using allocator_t = allocator_at; - using byte_t = typename allocator_t::value_type; - static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); - - element_t* slots_{}; - /// @brief Number of slots. - std::size_t capacity_{}; - /// @brief Number of populated. - std::size_t count_{}; - hasher_t hasher_{}; - -public: - growing_hash_set_gt() noexcept {} - ~growing_hash_set_gt() noexcept { reset(); } - - explicit operator bool() const noexcept { return slots_; } - std::size_t size() const noexcept { return count_; } - - void clear() noexcept { - if (slots_) - std::memset((void*)slots_, 0xFF, capacity_ * sizeof(element_t)); - count_ = 0; - } - - void reset() noexcept { - if (slots_) - allocator_t{}.deallocate((byte_t*)slots_, capacity_ * sizeof(element_t)); - slots_ = nullptr; - capacity_ = 0; - count_ = 0; - } - - growing_hash_set_gt(std::size_t capacity) noexcept - : slots_((element_t*)allocator_t{}.allocate(ceil2(capacity) * sizeof(element_t))), - capacity_(slots_ ? ceil2(capacity) : 0u), count_(0u) { - clear(); - } - - growing_hash_set_gt(growing_hash_set_gt&& other) noexcept { - slots_ = exchange(other.slots_, nullptr); - capacity_ = exchange(other.capacity_, 0); - count_ = exchange(other.count_, 0); - } - - growing_hash_set_gt& operator=(growing_hash_set_gt&& other) noexcept { - std::swap(slots_, other.slots_); - std::swap(capacity_, other.capacity_); - std::swap(count_, other.count_); - return *this; - } - - growing_hash_set_gt(growing_hash_set_gt const&) = delete; - growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; - - inline bool test(element_t const& elem) const noexcept { - std::size_t index = hasher_(elem) & (capacity_ - 1); - while (slots_[index] != default_free_value()) { - if (slots_[index] == elem) - return true; - - index = (index + 1) & (capacity_ - 1); - } - return false; - } - - /** - * - * @return Similar to `bitset_gt`, returns the previous value. 
- */ - inline bool set(element_t const& elem) noexcept { - std::size_t index = hasher_(elem) & (capacity_ - 1); - while (slots_[index] != default_free_value()) { - // Already exists - if (slots_[index] == elem) - return true; - - index = (index + 1) & (capacity_ - 1); - } - slots_[index] = elem; - ++count_; - return false; - } - - bool reserve(std::size_t new_capacity) noexcept { - new_capacity = (new_capacity * 5u) / 3u; - if (new_capacity <= capacity_) - return true; - - new_capacity = ceil2(new_capacity); - element_t* new_slots = (element_t*)allocator_t{}.allocate(new_capacity * sizeof(element_t)); - if (!new_slots) - return false; - - std::memset((void*)new_slots, 0xFF, new_capacity * sizeof(element_t)); - std::size_t new_count = count_; - if (count_) { - for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { - if (slots_[old_index] == default_free_value()) - continue; - - std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); - while (new_slots[new_index] != default_free_value()) - new_index = (new_index + 1) & (new_capacity - 1); - new_slots[new_index] = slots_[old_index]; - } - } - - reset(); - slots_ = new_slots; - capacity_ = new_capacity; - count_ = new_count; - return true; - } + using element_t = element_at; + using hasher_t = hasher_at; + + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + element_t* slots_{}; + /// @brief Number of slots. + std::size_t capacity_{}; + /// @brief Number of populated. + std::size_t count_{}; + hasher_t hasher_{}; + + public: + growing_hash_set_gt() noexcept {} + ~growing_hash_set_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + std::size_t size() const noexcept { return count_; } + + void clear() noexcept { + if (slots_) + std::memset((void*)slots_, 0xFF, capacity_ * sizeof(element_t)); + count_ = 0; + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, capacity_ * sizeof(element_t)); + slots_ = nullptr; + capacity_ = 0; + count_ = 0; + } + + growing_hash_set_gt(std::size_t capacity) noexcept + : slots_((element_t*)allocator_t{}.allocate(ceil2(capacity) * sizeof(element_t))), + capacity_(slots_ ? ceil2(capacity) : 0u), count_(0u) { + clear(); + } + + growing_hash_set_gt(growing_hash_set_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + capacity_ = exchange(other.capacity_, 0); + count_ = exchange(other.count_, 0); + } + + growing_hash_set_gt& operator=(growing_hash_set_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(capacity_, other.capacity_); + std::swap(count_, other.count_); + return *this; + } + + growing_hash_set_gt(growing_hash_set_gt const&) = delete; + growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; + + inline bool test(element_t const& elem) const noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + return false; + } + + /** + * + * @return Similar to `bitset_gt`, returns the previous value. 
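     *
     * For illustration (an editorial sketch, not upstream documentation), when the set tracks
     * visited slots during traversal, a `false` return marks the first visit:
     *
     *     growing_hash_set_gt<std::uint32_t> visited;
     *     if (visited.reserve(expected_nodes))   // must precede insertions
     *         if (!visited.set(slot))
     *             expand(slot);                  // `expand` and `slot` are illustrative names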
+ */ + inline bool set(element_t const& elem) noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + // Already exists + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + slots_[index] = elem; + ++count_; + return false; + } + + bool reserve(std::size_t new_capacity) noexcept { + new_capacity = (new_capacity * 5u) / 3u; + if (new_capacity <= capacity_) + return true; + + new_capacity = ceil2(new_capacity); + element_t* new_slots = (element_t*)allocator_t{}.allocate(new_capacity * sizeof(element_t)); + if (!new_slots) + return false; + + std::memset((void*)new_slots, 0xFF, new_capacity * sizeof(element_t)); + std::size_t new_count = count_; + if (count_) { + for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { + if (slots_[old_index] == default_free_value()) + continue; + + std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); + while (new_slots[new_index] != default_free_value()) + new_index = (new_index + 1) & (new_capacity - 1); + new_slots[new_index] = slots_[old_index]; + } + } + + reset(); + slots_ = new_slots; + capacity_ = new_capacity; + count_ = new_count; + return true; + } }; /** -* @brief Basic single-threaded @b ring class, used for all kinds of task queues. -*/ + * @brief Basic single-threaded @b ring class, used for all kinds of task queues. + */ template > // class ring_gt { -public: - using element_t = element_at; - using allocator_t = allocator_at; - - static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); - static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); - - using value_type = element_t; - -private: - element_t* elements_{}; - std::size_t capacity_{}; - std::size_t head_{}; - std::size_t tail_{}; - bool empty_{true}; - allocator_t allocator_{}; - -public: - explicit ring_gt(allocator_t const& alloc = allocator_t()) noexcept : allocator_(alloc) {} - - ring_gt(ring_gt const&) = delete; - ring_gt& operator=(ring_gt const&) = delete; - - ring_gt(ring_gt&& other) noexcept { swap(other); } - ring_gt& operator=(ring_gt&& other) noexcept { - swap(other); - return *this; - } - - void swap(ring_gt& other) noexcept { - std::swap(elements_, other.elements_); - std::swap(capacity_, other.capacity_); - std::swap(head_, other.head_); - std::swap(tail_, other.tail_); - std::swap(empty_, other.empty_); - std::swap(allocator_, other.allocator_); - } - - ~ring_gt() noexcept { reset(); } - - bool empty() const noexcept { return empty_; } - size_t capacity() const noexcept { return capacity_; } - size_t size() const noexcept { - if (empty_) - return 0; - else if (head_ >= tail_) - return head_ - tail_; - else - return capacity_ - (tail_ - head_); - } - - void clear() noexcept { - head_ = 0; - tail_ = 0; - empty_ = true; - } - - void reset() noexcept { - if (elements_) - allocator_.deallocate(elements_, capacity_); - elements_ = nullptr; - capacity_ = 0; - head_ = 0; - tail_ = 0; - empty_ = true; - } - - bool reserve(std::size_t n) noexcept { - if (n < size()) - return false; // prevent data loss - if (n <= capacity()) - return true; - n = (std::max)(ceil2(n), 64u); - element_t* elements = allocator_.allocate(n); - if (!elements) - return false; - - std::size_t i = 0; - while (try_pop(elements[i])) - i++; - - reset(); - elements_ = elements; - capacity_ = n; - head_ = i; - tail_ = 0; - empty_ = (i == 0); - return true; - } - - void push(element_t const& value) 
noexcept { - elements_[head_] = value; - head_ = (head_ + 1) % capacity_; - empty_ = false; - } - - bool try_push(element_t const& value) noexcept { - if (head_ == tail_ && !empty_) - return false; // elements_ is full - - return push(value); - return true; - } - - bool try_pop(element_t& value) noexcept { - if (empty_) - return false; - - value = std::move(elements_[tail_]); - tail_ = (tail_ + 1) % capacity_; - empty_ = head_ == tail_; - return true; - } - - element_t const& operator[](std::size_t i) const noexcept { return elements_[(tail_ + i) % capacity_]; } + public: + using element_t = element_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_{}; + std::size_t capacity_{}; + std::size_t head_{}; + std::size_t tail_{}; + bool empty_{true}; + allocator_t allocator_{}; + + public: + explicit ring_gt(allocator_t const& alloc = allocator_t()) noexcept : allocator_(alloc) {} + + ring_gt(ring_gt const&) = delete; + ring_gt& operator=(ring_gt const&) = delete; + + ring_gt(ring_gt&& other) noexcept { swap(other); } + ring_gt& operator=(ring_gt&& other) noexcept { + swap(other); + return *this; + } + + void swap(ring_gt& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(capacity_, other.capacity_); + std::swap(head_, other.head_); + std::swap(tail_, other.tail_); + std::swap(empty_, other.empty_); + std::swap(allocator_, other.allocator_); + } + + ~ring_gt() noexcept { reset(); } + + bool empty() const noexcept { return empty_; } + size_t capacity() const noexcept { return capacity_; } + size_t size() const noexcept { + if (empty_) + return 0; + else if (head_ >= tail_) + return head_ - tail_; + else + return capacity_ - (tail_ - head_); + } + + void clear() noexcept { + head_ = 0; + tail_ = 0; + empty_ = true; + } + + void reset() noexcept { + if (elements_) + allocator_.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + head_ = 0; + tail_ = 0; + empty_ = true; + } + + bool reserve(std::size_t n) noexcept { + if (n < size()) + return false; // prevent data loss + if (n <= capacity()) + return true; + n = (std::max)(ceil2(n), 64u); + element_t* elements = allocator_.allocate(n); + if (!elements) + return false; + + std::size_t i = 0; + while (try_pop(elements[i])) + i++; + + reset(); + elements_ = elements; + capacity_ = n; + head_ = i; + tail_ = 0; + empty_ = (i == 0); + return true; + } + + void push(element_t const& value) noexcept { + elements_[head_] = value; + head_ = (head_ + 1) % capacity_; + empty_ = false; + } + + bool try_push(element_t const& value) noexcept { + if (head_ == tail_ && !empty_) + return false; // elements_ is full + + return push(value); + return true; + } + + bool try_pop(element_t& value) noexcept { + if (empty_) + return false; + + value = std::move(elements_[tail_]); + tail_ = (tail_ + 1) % capacity_; + empty_ = head_ == tail_; + return true; + } + + element_t const& operator[](std::size_t i) const noexcept { return elements_[(tail_ + i) % capacity_]; } }; /// @brief Number of neighbors per graph node. @@ -1097,476 +1109,478 @@ constexpr std::size_t default_expansion_search() { return 64; } constexpr std::size_t default_allocator_entry_bytes() { return 64; } /** -* @brief Configuration settings for the index construction. 
-* Includes the main `::connectivity` parameter (`M` in the paper) -* and two expansion factors - for construction and search. -*/ + * @brief Configuration settings for the index construction. + * Includes the main `::connectivity` parameter (`M` in the paper) + * and two expansion factors - for construction and search. + */ struct index_config_t { - /// @brief Number of neighbors per graph node. - /// Defaults to 32 in FAISS and 16 in hnswlib. - /// > It is called `M` in the paper. - std::size_t connectivity = default_connectivity(); - - /// @brief Number of neighbors per graph node in base level graph. - /// Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib. - /// > It is called `M0` in the paper. - std::size_t connectivity_base = default_connectivity() * 2; - - inline index_config_t() = default; - inline index_config_t(std::size_t c) noexcept - : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) {} - inline index_config_t(std::size_t c, std::size_t cb) noexcept - : connectivity(c), connectivity_base((std::max)(c, cb)) {} + /// @brief Number of neighbors per graph node. + /// Defaults to 32 in FAISS and 16 in hnswlib. + /// > It is called `M` in the paper. + std::size_t connectivity = default_connectivity(); + + /// @brief Number of neighbors per graph node in base level graph. + /// Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib. + /// > It is called `M0` in the paper. + std::size_t connectivity_base = default_connectivity() * 2; + + inline index_config_t() = default; + inline index_config_t(std::size_t c) noexcept + : connectivity(c ? c : default_connectivity()), connectivity_base(c ? c * 2 : default_connectivity() * 2) {} + inline index_config_t(std::size_t c, std::size_t cb) noexcept + : connectivity(c), connectivity_base((std::max)(c, cb)) {} }; struct index_limits_t { - std::size_t members = 0; - std::size_t threads_add = std::thread::hardware_concurrency(); - std::size_t threads_search = std::thread::hardware_concurrency(); - - inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {} - inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {} - inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); } - inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); } + std::size_t members = 0; + std::size_t threads_add = std::thread::hardware_concurrency(); + std::size_t threads_search = std::thread::hardware_concurrency(); + + inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {} + inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {} + inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); } + inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); } }; struct index_update_config_t { - /// @brief Hyper-parameter controlling the quality of indexing. - /// Defaults to 40 in FAISS and 200 in hnswlib. - /// > It is called `efConstruction` in the paper. - std::size_t expansion = default_expansion_add(); + /// @brief Hyper-parameter controlling the quality of indexing. + /// Defaults to 40 in FAISS and 200 in hnswlib. + /// > It is called `efConstruction` in the paper. 
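    /// For illustration (an editorial sketch, not upstream documentation), construction and
    /// search typically pair a larger build-time expansion with a smaller query-time one:
    ///
    ///     index_config_t config(16);           // M = 16, so the base level gets M0 = 32
    ///     index_update_config_t add_config;
    ///     add_config.expansion = 128;          // efConstruction
    ///     index_search_config_t search_config; // defined just below
    ///     search_config.expansion = 64;        // ef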
+ std::size_t expansion = default_expansion_add(); - /// @brief Optional thread identifier for multi-threaded construction. - std::size_t thread = 0; + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; }; struct index_search_config_t { - /// @brief Hyper-parameter controlling the quality of search. - /// Defaults to 16 in FAISS and 10 in hnswlib. - /// > It is called `ef` in the paper. - std::size_t expansion = default_expansion_search(); + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); - /// @brief Optional thread identifier for multi-threaded construction. - std::size_t thread = 0; + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; - /// @brief Brute-forces exhaustive search over all entries in the index. - bool exact = false; + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; }; struct index_cluster_config_t { - /// @brief Hyper-parameter controlling the quality of search. - /// Defaults to 16 in FAISS and 10 in hnswlib. - /// > It is called `ef` in the paper. - std::size_t expansion = default_expansion_search(); + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); - /// @brief Optional thread identifier for multi-threaded construction. - std::size_t thread = 0; + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; }; struct index_copy_config_t {}; struct index_join_config_t { - /// @brief Controls maximum number of proposals per man during stable marriage. - std::size_t max_proposals = 0; + /// @brief Controls maximum number of proposals per man during stable marriage. + std::size_t max_proposals = 0; - /// @brief Hyper-parameter controlling the quality of search. - /// Defaults to 16 in FAISS and 10 in hnswlib. - /// > It is called `ef` in the paper. - std::size_t expansion = default_expansion_search(); + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); - /// @brief Brute-forces exhaustive search over all entries in the index. - bool exact = false; + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; }; /// @brief C++17 and newer version deprecate the `std::result_of` template using return_type_gt = #if defined(USEARCH_DEFINED_CPP17) - typename std::invoke_result::type; + typename std::invoke_result::type; #else - typename std::result_of::type; + typename std::result_of::type; #endif /** -* @brief An example of what a USearch-compatible ad-hoc filter would look like. -* -* A similar function object can be passed to search queries to further filter entries -* on their auxiliary properties, such as some categorical keys stored in an external DBMS. -*/ + * @brief An example of what a USearch-compatible ad-hoc filter would look like. + * + * A similar function object can be passed to search queries to further filter entries + * on their auxiliary properties, such as some categorical keys stored in an external DBMS. 
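 *
 * For illustration (an editorial sketch, not upstream documentation), a real predicate could
 * consult an external allow-list keyed by the member's key:
 *
 *     struct allowed_keys_predicate_t {
 *         std::unordered_set<std::uint64_t> const* allowed = nullptr;
 *         template <typename member_at> bool operator()(member_at&& member) const noexcept {
 *             return allowed->count(get_key(member)) != 0;
 *         }
 *     };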
+ */ struct dummy_predicate_t { - template constexpr bool operator()(member_at&&) const noexcept { return true; } + template constexpr bool operator()(member_at&&) const noexcept { return true; } }; /** -* @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries. -* -* This kind of callbacks is used when the engine is being updated and you want to patch -* the entries, while their are still under locks - limiting concurrent access and providing -* consistency. -*/ + * @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries. + * + * This kind of callbacks is used when the engine is being updated and you want to patch + * the entries, while their are still under locks - limiting concurrent access and providing + * consistency. + */ struct dummy_callback_t { - template void operator()(member_at&&) const noexcept {} + template void operator()(member_at&&) const noexcept {} }; /** -* @brief An example of what a USearch-compatible progress-bar should look like. -* -* This is particularly helpful when handling long-running tasks, like serialization, -* saving, and loading from disk, or index-level joins. -* The reporter checks return value to continue or stop the process, `false` means need to stop. -*/ + * @brief An example of what a USearch-compatible progress-bar should look like. + * + * This is particularly helpful when handling long-running tasks, like serialization, + * saving, and loading from disk, or index-level joins. + * The reporter checks return value to continue or stop the process, `false` means need to stop. + */ struct dummy_progress_t { - inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; } + inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; } }; /** -* @brief An example of what a USearch-compatible values prefetching mechanism should look like. -* -* USearch is designed to handle very large datasets, that may not fir into RAM. Fetching from -* external memory is very expensive, so we've added a pre-fetching mechanism, that accepts -* multiple objects at once, to cache in RAM ahead of the computation. -* The received iterators support both `get_slot` and `get_key` operations. -* An example usage may look like this: -* -* template -* inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept { -* for (; begin != end; ++begin) -* io_uring_prefetch(offset_in_file(get_key(begin))); -* } -*/ + * @brief An example of what a USearch-compatible values prefetching mechanism should look like. + * + * USearch is designed to handle very large datasets, that may not fir into RAM. Fetching from + * external memory is very expensive, so we've added a pre-fetching mechanism, that accepts + * multiple objects at once, to cache in RAM ahead of the computation. + * The received iterators support both `get_slot` and `get_key` operations. 
+ * An example usage may look like this: + * + * template + * inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept { + * for (; begin != end; ++begin) + * io_uring_prefetch(offset_in_file(get_key(begin))); + * } + */ struct dummy_prefetch_t { - template - inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {} + template + inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {} }; /** -* @brief An example of what a USearch-compatible executor (thread-pool) should look like. -* -* It's expected to have `parallel(callback)` API to schedule one task per thread; -* an identical `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accepts -* the number of tasks, and somehow schedules them between threads; as well as `size()` to -* determine the number of available threads. -*/ + * @brief An example of what a USearch-compatible executor (thread-pool) should look like. + * + * It's expected to have `parallel(callback)` API to schedule one task per thread; + * an identical `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accepts + * the number of tasks, and somehow schedules them between threads; as well as `size()` to + * determine the number of available threads. + */ struct dummy_executor_t { - dummy_executor_t() noexcept {} - std::size_t size() const noexcept { return 1; } - - template - void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { - for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) - thread_aware_function(0, task_idx); - } - - template - void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { - for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) - if (!thread_aware_function(0, task_idx)) - break; - } - - template - void parallel(thread_aware_function_at&& thread_aware_function) noexcept { - thread_aware_function(0); - } + dummy_executor_t() noexcept {} + std::size_t size() const noexcept { return 1; } + + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + thread_aware_function(0, task_idx); + } + + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + if (!thread_aware_function(0, task_idx)) + break; + } + + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept { + thread_aware_function(0); + } }; /** -* @brief An example of what a USearch-compatible key-to-key mapping should look like. -* -* This is particularly helpful for "Semantic Joins", where we map entries of one collection -* to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not, -* this can be passed to minimize memory usage. -*/ + * @brief An example of what a USearch-compatible key-to-key mapping should look like. + * + * This is particularly helpful for "Semantic Joins", where we map entries of one collection + * to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not, + * this can be passed to minimize memory usage. 
+ */ struct dummy_key_to_key_mapping_t { - struct member_ref_t { - template member_ref_t& operator=(key_at&&) noexcept { return *this; } - }; - template member_ref_t operator[](key_at&&) const noexcept { return {}; } + struct member_ref_t { + template member_ref_t& operator=(key_at&&) noexcept { return *this; } + }; + template member_ref_t operator[](key_at&&) const noexcept { return {}; } }; /** -* @brief Checks if the provided object has a dummy type, emulating an interface, -* but performing no real computation. -*/ + * @brief Checks if the provided object has a dummy type, emulating an interface, + * but performing no real computation. + */ template static constexpr bool is_dummy() { - using object_t = typename std::remove_all_extents::type; - return std::is_same::type, dummy_predicate_t>::value || // - std::is_same::type, dummy_callback_t>::value || // - std::is_same::type, dummy_progress_t>::value || // - std::is_same::type, dummy_prefetch_t>::value || // - std::is_same::type, dummy_executor_t>::value || // - std::is_same::type, dummy_key_to_key_mapping_t>::value; + using object_t = typename std::remove_all_extents::type; + return std::is_same::type, dummy_predicate_t>::value || // + std::is_same::type, dummy_callback_t>::value || // + std::is_same::type, dummy_progress_t>::value || // + std::is_same::type, dummy_prefetch_t>::value || // + std::is_same::type, dummy_executor_t>::value || // + std::is_same::type, dummy_key_to_key_mapping_t>::value; } template struct has_reset_gt { - static_assert(std::integral_constant::value, "Second template parameter needs to be of function type."); + static_assert(std::integral_constant::value, "Second template parameter needs to be of function type."); }; template struct has_reset_gt { -private: - template - static constexpr auto check(at*) -> - typename std::is_same().reset(std::declval()...)), return_at>::type; - template static constexpr std::false_type check(...); + private: + template + static constexpr auto check(at*) -> + typename std::is_same().reset(std::declval()...)), return_at>::type; + template static constexpr std::false_type check(...); - typedef decltype(check(0)) type; + typedef decltype(check(0)) type; -public: - static constexpr bool value = type::value; + public: + static constexpr bool value = type::value; }; /** -* @brief Checks if a certain class has a member function called `reset`. -*/ + * @brief Checks if a certain class has a member function called `reset`. + */ template constexpr bool has_reset() { return has_reset_gt::value; } struct serialization_result_t { - error_t error; + error_t error; - explicit operator bool() const noexcept { return !error; } - serialization_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } + explicit operator bool() const noexcept { return !error; } + serialization_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; /** -* @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs. -* -* This class raises no exceptions and corresponds errors through `serialization_result_t`. -* The class automatically closes the file when the object is destroyed. -*/ + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. 
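 *
 * A minimal usage sketch (editorial, hedged), relying only on the members declared below;
 * the file name and the `header` object are illustrative:
 *
 *     output_file_t file("index.usearch");
 *     serialization_result_t result = file.open_if_not();
 *     if (result)
 *         result = file.write(&header, sizeof(header));
 *     // the destructor closes the handle; `result.error` carries any failure message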
+ */ class output_file_t { - char const* path_ = nullptr; - std::FILE* file_ = nullptr; - -public: - output_file_t(char const* path) noexcept : path_(path) {} - ~output_file_t() noexcept { close(); } - output_file_t(output_file_t&& other) noexcept - : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} - output_file_t& operator=(output_file_t&& other) noexcept { - std::swap(path_, other.path_); - std::swap(file_, other.file_); - return *this; - } - serialization_result_t open_if_not() noexcept { - serialization_result_t result; - if (!file_) - file_ = std::fopen(path_, "wb"); - if (!file_) - return result.failed(std::strerror(errno)); - return result; - } - serialization_result_t write(void const* begin, std::size_t length) noexcept { - serialization_result_t result; - std::size_t written = std::fwrite(begin, length, 1, file_); - if (length && !written) - return result.failed(std::strerror(errno)); - return result; - } - void close() noexcept { - if (file_) - std::fclose(exchange(file_, nullptr)); - } + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + output_file_t(char const* path) noexcept : path_(path) {} + ~output_file_t() noexcept { close(); } + output_file_t(output_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + output_file_t& operator=(output_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "wb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t write(void const* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t written = std::fwrite(begin, length, 1, file_); + if (length && !written) + return result.failed(std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } }; /** -* @brief Smart-pointer wrapping the LibC @b `FILE` for binary files @b inputs. -* -* This class raises no exceptions and corresponds errors through `serialization_result_t`. -* The class automatically closes the file when the object is destroyed. -*/ + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary files @b inputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. + */ class input_file_t { - char const* path_ = nullptr; - std::FILE* file_ = nullptr; - -public: - input_file_t(char const* path) noexcept : path_(path) {} - ~input_file_t() noexcept { close(); } - input_file_t(input_file_t&& other) noexcept - : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} - input_file_t& operator=(input_file_t&& other) noexcept { - std::swap(path_, other.path_); - std::swap(file_, other.file_); - return *this; - } - - serialization_result_t open_if_not() noexcept { - serialization_result_t result; - if (!file_) - file_ = std::fopen(path_, "rb"); - if (!file_) - return result.failed(std::strerror(errno)); - return result; - } - serialization_result_t read(void* begin, std::size_t length) noexcept { - serialization_result_t result; - std::size_t read = std::fread(begin, length, 1, file_); - if (length && !read) - return result.failed(std::feof(file_) ? "End of file reached!" 
: std::strerror(errno)); - return result; - } - void close() noexcept { - if (file_) - std::fclose(exchange(file_, nullptr)); - } - - explicit operator bool() const noexcept { return file_; } - bool seek_to(std::size_t progress) noexcept { return std::fseek(file_, progress, SEEK_SET) == 0; } - bool seek_to_end() noexcept { return std::fseek(file_, 0L, SEEK_END) == 0; } - bool infer_progress(std::size_t& progress) noexcept { - long int result = std::ftell(file_); - if (result == -1L) - return false; - progress = static_cast(result); - return true; - } + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + input_file_t(char const* path) noexcept : path_(path) {} + ~input_file_t() noexcept { close(); } + input_file_t(input_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + input_file_t& operator=(input_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "rb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t read(void* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t read = std::fread(begin, length, 1, file_); + if (length && !read) + return result.failed(std::feof(file_) ? "End of file reached!" : std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } + + explicit operator bool() const noexcept { return file_; } + bool seek_to(std::size_t progress) noexcept { + return std::fseek(file_, static_cast(progress), SEEK_SET) == 0; + } + bool seek_to_end() noexcept { return std::fseek(file_, 0L, SEEK_END) == 0; } + bool infer_progress(std::size_t& progress) noexcept { + long int result = std::ftell(file_); + if (result == -1L) + return false; + progress = static_cast(result); + return true; + } }; /** -* @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. -* -* This class provides a convenient way to memory-map a file and access its contents as a block of -* memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. -* The class automatically closes the file when the object is destroyed. -*/ + * @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. + * + * This class provides a convenient way to memory-map a file and access its contents as a block of + * memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. + * The class automatically closes the file when the object is destroyed. + */ class memory_mapped_file_t { - char const* path_{}; /**< The path to the file to be memory-mapped. */ - void* ptr_{}; /**< A pointer to the memory-mapping. */ - size_t length_{}; /**< The length of the memory-mapped file in bytes. */ + char const* path_{}; /**< The path to the file to be memory-mapped. */ + void* ptr_{}; /**< A pointer to the memory-mapping. */ + size_t length_{}; /**< The length of the memory-mapped file in bytes. */ #if defined(USEARCH_DEFINED_WINDOWS) - HANDLE file_handle_{}; /**< The file handle on Windows. */ - HANDLE mapping_handle_{}; /**< The mapping handle on Windows. */ + HANDLE file_handle_{}; /**< The file handle on Windows. */ + HANDLE mapping_handle_{}; /**< The mapping handle on Windows. 
*/ #else - int file_descriptor_{}; /**< The file descriptor on Linux and MacOS. */ + int file_descriptor_{}; /**< The file descriptor on Linux and MacOS. */ #endif -public: - explicit operator bool() const noexcept { return ptr_ != nullptr; } - byte_t* data() noexcept { return reinterpret_cast(ptr_); } - byte_t const* data() const noexcept { return reinterpret_cast(ptr_); } - std::size_t size() const noexcept { return static_cast(length_); } - - memory_mapped_file_t() noexcept {} - memory_mapped_file_t(char const* path) noexcept : path_(path) {} - ~memory_mapped_file_t() noexcept { close(); } - memory_mapped_file_t(memory_mapped_file_t&& other) noexcept - : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), - length_(exchange(other.length_, 0)), + public: + explicit operator bool() const noexcept { return ptr_ != nullptr; } + byte_t* data() noexcept { return reinterpret_cast(ptr_); } + byte_t const* data() const noexcept { return reinterpret_cast(ptr_); } + std::size_t size() const noexcept { return static_cast(length_); } + + memory_mapped_file_t() noexcept {} + memory_mapped_file_t(char const* path) noexcept : path_(path) {} + ~memory_mapped_file_t() noexcept { close(); } + memory_mapped_file_t(memory_mapped_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), + length_(exchange(other.length_, 0)), #if defined(USEARCH_DEFINED_WINDOWS) - file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) + file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) #else - file_descriptor_(exchange(other.file_descriptor_, 0)) + file_descriptor_(exchange(other.file_descriptor_, 0)) #endif - { - } + { + } - memory_mapped_file_t(byte_t* data, std::size_t length) noexcept : ptr_(data), length_(length) {} + memory_mapped_file_t(byte_t* data, std::size_t length) noexcept : ptr_(data), length_(length) {} - memory_mapped_file_t& operator=(memory_mapped_file_t&& other) noexcept { - std::swap(path_, other.path_); - std::swap(ptr_, other.ptr_); - std::swap(length_, other.length_); + memory_mapped_file_t& operator=(memory_mapped_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(ptr_, other.ptr_); + std::swap(length_, other.length_); #if defined(USEARCH_DEFINED_WINDOWS) - std::swap(file_handle_, other.file_handle_); - std::swap(mapping_handle_, other.mapping_handle_); + std::swap(file_handle_, other.file_handle_); + std::swap(mapping_handle_, other.mapping_handle_); #else - std::swap(file_descriptor_, other.file_descriptor_); + std::swap(file_descriptor_, other.file_descriptor_); #endif - return *this; - } + return *this; + } - serialization_result_t open_if_not() noexcept { - serialization_result_t result; - if (!path_ || ptr_) - return result; + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!path_ || ptr_) + return result; #if defined(USEARCH_DEFINED_WINDOWS) - HANDLE file_handle = - CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); - if (file_handle == INVALID_HANDLE_VALUE) - return result.failed("Opening file failed!"); - - std::size_t file_length = GetFileSize(file_handle, 0); - HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); - if (mapping_handle == 0) { - CloseHandle(file_handle); - return result.failed("Mapping file failed!"); - } - - byte_t* file = (byte_t*)MapViewOfFile(mapping_handle, 
FILE_MAP_READ, 0, 0, file_length); - if (file == 0) { - CloseHandle(mapping_handle); - CloseHandle(file_handle); - return result.failed("View the map failed!"); - } - file_handle_ = file_handle; - mapping_handle_ = mapping_handle; - ptr_ = file; - length_ = file_length; + HANDLE file_handle = + CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (file_handle == INVALID_HANDLE_VALUE) + return result.failed("Opening file failed!"); + + std::size_t file_length = GetFileSize(file_handle, 0); + HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); + if (mapping_handle == 0) { + CloseHandle(file_handle); + return result.failed("Mapping file failed!"); + } + + byte_t* file = (byte_t*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_length); + if (file == 0) { + CloseHandle(mapping_handle); + CloseHandle(file_handle); + return result.failed("View the map failed!"); + } + file_handle_ = file_handle; + mapping_handle_ = mapping_handle; + ptr_ = file; + length_ = file_length; #else #if defined(USEARCH_DEFINED_LINUX) - int descriptor = open(path_, O_RDONLY | O_NOATIME); + int descriptor = open(path_, O_RDONLY | O_NOATIME); #else - int descriptor = open(path_, O_RDONLY); + int descriptor = open(path_, O_RDONLY); #endif - if (descriptor < 0) - return result.failed(std::strerror(errno)); - - // Estimate the file size - struct stat file_stat; - int fstat_status = fstat(descriptor, &file_stat); - if (fstat_status < 0) { - ::close(descriptor); - return result.failed(std::strerror(errno)); - } - - // Map the entire file - byte_t* file = (byte_t*)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); - if (file == MAP_FAILED) { - ::close(descriptor); - return result.failed(std::strerror(errno)); - } - file_descriptor_ = descriptor; - ptr_ = file; - length_ = file_stat.st_size; + if (descriptor < 0) + return result.failed(std::strerror(errno)); + + // Estimate the file size + struct stat file_stat; + int fstat_status = fstat(descriptor, &file_stat); + if (fstat_status < 0) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + + // Map the entire file + byte_t* file = (byte_t*)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); + if (file == MAP_FAILED) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + file_descriptor_ = descriptor; + ptr_ = file; + length_ = file_stat.st_size; #endif // Platform specific code - return result; - } - - void close() noexcept { - if (!path_) { - ptr_ = nullptr; - length_ = 0; - return; - } + return result; + } + + void close() noexcept { + if (!path_) { + ptr_ = nullptr; + length_ = 0; + return; + } #if defined(USEARCH_DEFINED_WINDOWS) - UnmapViewOfFile(ptr_); - CloseHandle(mapping_handle_); - CloseHandle(file_handle_); - mapping_handle_ = nullptr; - file_handle_ = nullptr; + UnmapViewOfFile(ptr_); + CloseHandle(mapping_handle_); + CloseHandle(file_handle_); + mapping_handle_ = nullptr; + file_handle_ = nullptr; #else - munmap(ptr_, length_); - ::close(file_descriptor_); - file_descriptor_ = 0; + munmap(ptr_, length_); + ::close(file_descriptor_); + file_descriptor_ = 0; #endif - ptr_ = nullptr; - length_ = 0; - } + ptr_ = nullptr; + length_ = 0; + } }; struct index_serialized_header_t { - std::uint64_t size = 0; - std::uint64_t connectivity = 0; - std::uint64_t connectivity_base = 0; - std::uint64_t max_level = 0; - std::uint64_t entry_slot = 0; + std::uint64_t size = 0; + std::uint64_t connectivity = 0; + 
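// Editorial sketch (hedged, not upstream documentation): the `memory_mapped_file_t` wrapper
// above exposes the mapping only through `data()` and `size()`; the path is illustrative:
//
//     memory_mapped_file_t file("index.usearch");
//     serialization_result_t result = file.open_if_not();
//     if (result) {
//         byte_t const* bytes = file.data();   // `size()` bytes, valid until destruction
//         std::size_t length = file.size();
//     }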
std::uint64_t connectivity_base = 0; + std::uint64_t max_level = 0; + std::uint64_t entry_slot = 0; }; using default_key_t = std::uint64_t; @@ -1574,2263 +1588,2280 @@ using default_slot_t = std::uint32_t; using default_distance_t = float; template struct member_gt { - key_at key; - std::size_t slot; + key_at key; + std::size_t slot; }; template inline std::size_t get_slot(member_gt const& m) noexcept { return m.slot; } template inline key_at get_key(member_gt const& m) noexcept { return m.key; } template struct member_cref_gt { - misaligned_ref_gt key; - std::size_t slot; + misaligned_ref_gt key; + std::size_t slot; }; template inline std::size_t get_slot(member_cref_gt const& m) noexcept { return m.slot; } template inline key_at get_key(member_cref_gt const& m) noexcept { return m.key; } template struct member_ref_gt { - misaligned_ref_gt key; - std::size_t slot; + misaligned_ref_gt key; + std::size_t slot; - inline operator member_cref_gt() const noexcept { return {key.ptr(), slot}; } + inline operator member_cref_gt() const noexcept { return {key.ptr(), slot}; } }; template inline std::size_t get_slot(member_ref_gt const& m) noexcept { return m.slot; } template inline key_at get_key(member_ref_gt const& m) noexcept { return m.key; } /** -* @brief Approximate Nearest Neighbors Search @b index-structure using the -* Hierarchical Navigable Small World @b (HNSW) graphs algorithm. -* If classical containers store @b Key->Value mappings, this one can -* be seen as a network of keys, accelerating approximate @b Value~>Key visited_members. -* -* Unlike most implementations, this one is generic anc can be used for any search, -* not just within equi-dimensional vectors. Examples range from texts to similar Chess -* positions. -* -* @tparam key_at -* The type of primary objects stored in the index. -* The values, to which those map, are not managed by the same index structure. -* -* @tparam compressed_slot_at -* The smallest unsigned integer type to address indexed elements. -* It is used internally to maximize space-efficiency and is generally -* up-casted to @b `std::size_t` in public interfaces. -* Can be a built-in @b `uint32_t`, `uint64_t`, or our custom @b `uint40_t`. -* Which makes the most sense for 4B+ entry indexes. -* -* @tparam dynamic_allocator_at -* Dynamic memory allocator for temporary buffers, visits indicators, and -* priority queues, needed during construction and traversals of graphs. -* The allocated buffers may be uninitialized. -* -* @tparam tape_allocator_at -* Potentially different memory allocator for primary allocations of nodes and vectors. -* It would never `deallocate` separate entries, and would only free all the space at once. -* The allocated buffers may be uninitialized. -* -* @section Features -* -* - Thread-safe for concurrent construction, search, and updates. -* - Doesn't allocate new threads, and reuses the ones its called from. -* - Allows storing value externally, managing just the similarity index. -* - Joins. - -* @section Usage -* -* @subsection Exceptions -* -* None of the methods throw exceptions in the "Release" compilation mode. -* It may only `throw` if your memory ::dynamic_allocator_at or ::metric_at isn't -* safe to copy. -* -* @subsection Serialization -* -* When serialized, doesn't include any additional metadata. -* It is just the multi-level proximity-graph. You may want to store metadata about -* the used metric and key types somewhere else. 
-* -* @section Implementation Details -* -* Like every HNSW implementation, USearch builds levels of "Proximity Graphs". -* Every added vector forms a node in one or more levels of the graph. -* Every node is present in the base level. Every following level contains a smaller -* fraction of nodes. During search, the operation starts with the smaller levels -* and zooms-in on every following iteration of larger graph traversals. -* -* Just one memory allocation is performed regardless of the number of levels. -* The adjacency lists across all levels are concatenated into that single buffer. -* That buffer starts with a "head", that stores the metadata, such as the -* tallest "level" of the graph that it belongs to, the external "key", and the -* number of "dimensions" in the vector. -* -* @section Metrics, Predicates and Callbacks -* -* -* @section Smart References and Iterators -* -* - `member_citerator_t` and `member_iterator_t` have only slots, no indirections. -* -* - `member_cref_t` and `member_ref_t` contains the `slot` and a reference -* to the key. So it passes through 1 level of visited_members in `nodes_`. -* Retrieving the key via `get_key` will cause fetching yet another cache line. -* -* - `member_gt` contains an already prefetched copy of the key. -* -*/ + * @brief Approximate Nearest Neighbors Search @b index-structure using the + * Hierarchical Navigable Small World @b (HNSW) graphs algorithm. + * If classical containers store @b Key->Value mappings, this one can + * be seen as a network of keys, accelerating approximate @b Value~>Key visited_members. + * + * Unlike most implementations, this one is generic anc can be used for any search, + * not just within equi-dimensional vectors. Examples range from texts to similar Chess + * positions. + * + * @tparam key_at + * The type of primary objects stored in the index. + * The values, to which those map, are not managed by the same index structure. + * + * @tparam compressed_slot_at + * The smallest unsigned integer type to address indexed elements. + * It is used internally to maximize space-efficiency and is generally + * up-casted to @b `std::size_t` in public interfaces. + * Can be a built-in @b `uint32_t`, `uint64_t`, or our custom @b `uint40_t`. + * Which makes the most sense for 4B+ entry indexes. + * + * @tparam dynamic_allocator_at + * Dynamic memory allocator for temporary buffers, visits indicators, and + * priority queues, needed during construction and traversals of graphs. + * The allocated buffers may be uninitialized. + * + * @tparam tape_allocator_at + * Potentially different memory allocator for primary allocations of nodes and vectors. + * It would never `deallocate` separate entries, and would only free all the space at once. + * The allocated buffers may be uninitialized. + * + * @section Features + * + * - Thread-safe for concurrent construction, search, and updates. + * - Doesn't allocate new threads, and reuses the ones its called from. + * - Allows storing value externally, managing just the similarity index. + * - Joins. + + * @section Usage + * + * @subsection Exceptions + * + * None of the methods throw exceptions in the "Release" compilation mode. + * It may only `throw` if your memory ::dynamic_allocator_at or ::metric_at isn't + * safe to copy. + * + * @subsection Serialization + * + * When serialized, doesn't include any additional metadata. + * It is just the multi-level proximity-graph. You may want to store metadata about + * the used metric and key types somewhere else. 
+ * + * @section Implementation Details + * + * Like every HNSW implementation, USearch builds levels of "Proximity Graphs". + * Every added vector forms a node in one or more levels of the graph. + * Every node is present in the base level. Every following level contains a smaller + * fraction of nodes. During search, the operation starts with the smaller levels + * and zooms-in on every following iteration of larger graph traversals. + * + * Just one memory allocation is performed regardless of the number of levels. + * The adjacency lists across all levels are concatenated into that single buffer. + * That buffer starts with a "head", that stores the metadata, such as the + * tallest "level" of the graph that it belongs to, the external "key", and the + * number of "dimensions" in the vector. + * + * @section Metrics, Predicates and Callbacks + * + * + * @section Smart References and Iterators + * + * - `member_citerator_t` and `member_iterator_t` have only slots, no indirections. + * + * - `member_cref_t` and `member_ref_t` contains the `slot` and a reference + * to the key. So it passes through 1 level of visited_members in `nodes_`. + * Retrieving the key via `get_key` will cause fetching yet another cache line. + * + * - `member_gt` contains an already prefetched copy of the key. + * + */ template , // - typename tape_allocator_at = dynamic_allocator_at> // + typename key_at = default_key_t, // + typename compressed_slot_at = default_slot_t, // + typename dynamic_allocator_at = std::allocator, // + typename tape_allocator_at = dynamic_allocator_at> // class index_gt { -public: - using distance_t = distance_at; - using vector_key_t = key_at; - using key_t = vector_key_t; - using compressed_slot_t = compressed_slot_at; - using dynamic_allocator_t = dynamic_allocator_at; - using tape_allocator_t = tape_allocator_at; - static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); - - using member_cref_t = member_cref_gt; - using member_ref_t = member_ref_gt; - - template class member_iterator_gt { - using ref_t = ref_at; - using index_t = index_at; - - friend class index_gt; - member_iterator_gt() noexcept {} - member_iterator_gt(index_t* index, std::size_t slot) noexcept : index_(index), slot_(slot) {} - - index_t* index_{}; - std::size_t slot_{}; - - public: - using iterator_category = std::random_access_iterator_tag; - using value_type = ref_t; - using difference_type = std::ptrdiff_t; - using pointer = void; - using reference = ref_t; - - reference operator*() const noexcept { return {index_->node_at_(slot_).key(), slot_}; } - vector_key_t key() const noexcept { return index_->node_at_(slot_).key(); } - - friend inline std::size_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } - friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } - - member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, slot_ + 1); } - member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, slot_ - 1); } - member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, slot_ + d); } - member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, slot_ - d); } - - // clang-format off - member_iterator_gt& operator++() noexcept { slot_ += 1; return *this; } - member_iterator_gt& operator--() noexcept { slot_ -= 1; return *this; } - member_iterator_gt& operator+=(difference_type d) noexcept { slot_ += d; 
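Editorial note to make the `index_gt` documentation above concrete: a hedged end-to-end sketch of the low-level API. The index stores only keys and the proximity graph, so the vectors live in an external map and the metric resolves members back to their vectors via `get_key`, as the docs describe. The include path, the `l2_metric_t` helper, the 4-dimensional toy data, and the single-threaded limits are illustrative assumptions, not code from this patch.

    #include <array>
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>
    #include <unordered_map>
    #include "usearch/index.hpp" // path as vendored in this repository

    using namespace unum::usearch;

    // Squared-L2 metric over externally stored 4-d vectors. `add`/`search` call it both as
    // (query pointer, member) and (member, member); `resolve_` handles either argument kind.
    struct l2_metric_t {
        std::unordered_map<std::uint64_t, std::array<float, 4>> const* vectors;

        float const* resolve_(float const* raw) const noexcept { return raw; }
        template <typename entry_at> float const* resolve_(entry_at const& member) const noexcept {
            return vectors->at(get_key(member)).data(); // look the member's vector up by its key
        }
        template <typename first_at, typename second_at>
        float operator()(first_at const& a, second_at const& b) const noexcept {
            float const* x = resolve_(a);
            float const* y = resolve_(b);
            float sum = 0;
            for (std::size_t i = 0; i != 4; ++i)
                sum += (x[i] - y[i]) * (x[i] - y[i]);
            return sum;
        }
    };

    int main() {
        std::unordered_map<std::uint64_t, std::array<float, 4>> vectors = {
            {42, {0.f, 0.f, 0.f, 1.f}},
            {43, {0.f, 0.f, 1.f, 0.f}},
        };
        l2_metric_t metric{&vectors};

        index_gt<> index; // defaults: float distances, 64-bit keys, 32-bit compressed slots
        if (!index.reserve(index_limits_t{128, 1})) // capacity and contexts, before any `add`
            return 1;
        for (auto const& kv : vectors)
            index.add(kv.first, kv.second.data(), metric);

        float const query[4] = {0.f, 0.f, 0.9f, 0.1f};
        auto results = index.search(&query[0], 2, metric);
        std::uint64_t keys[2];
        float distances[2];
        std::size_t found = results.dump_to(keys, distances);
        for (std::size_t i = 0; i != found; ++i)
            std::printf("key=%llu distance=%.3f\n", (unsigned long long)keys[i], (double)distances[i]);
        return 0;
    }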
return *this; } - member_iterator_gt& operator-=(difference_type d) noexcept { slot_ -= d; return *this; } - bool operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } - bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != other.index_ || slot_ != other.slot_; } - // clang-format on - }; - - using member_iterator_t = member_iterator_gt; - using member_citerator_t = member_iterator_gt; - - // STL compatibility: - using value_type = vector_key_t; - using allocator_type = dynamic_allocator_t; - using size_type = std::size_t; - using difference_type = std::ptrdiff_t; - using reference = member_ref_t; - using const_reference = member_cref_t; - using pointer = void; - using const_pointer = void; - using iterator = member_iterator_t; - using const_iterator = member_citerator_t; - using reverse_iterator = std::reverse_iterator; - using reverse_const_iterator = std::reverse_iterator; - - using dynamic_allocator_traits_t = std::allocator_traits; - using byte_t = typename dynamic_allocator_t::value_type; - static_assert( // - sizeof(byte_t) == 1, // - "Primary allocator must allocate separate addressable bytes"); - - using tape_allocator_traits_t = std::allocator_traits; - static_assert( // - sizeof(typename tape_allocator_traits_t::value_type) == 1, // - "Tape allocator must allocate separate addressable bytes"); - -private: - /** - * @brief Integer for the number of node neighbors at a specific level of the - * multi-level graph. It's selected to be `std::uint32_t` to improve the - * alignment in most common cases. - */ - using neighbors_count_t = std::uint32_t; - using level_t = std::int16_t; - - /** - * @brief How many bytes of memory are needed to form the "head" of the node. - */ - static constexpr std::size_t node_head_bytes_() { return sizeof(vector_key_t) + sizeof(level_t); } - - using nodes_mutexes_t = bitset_gt; - - using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; - - struct precomputed_constants_t { - double inverse_log_connectivity{}; - std::size_t neighbors_bytes{}; - std::size_t neighbors_base_bytes{}; - }; - /// @brief A space-efficient internal data-structure used in graph traversal queues. - struct candidate_t { - distance_t distance; - compressed_slot_t slot; - inline bool operator<(candidate_t other) const noexcept { return distance < other.distance; } - }; - - using candidates_view_t = span_gt; - using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; - using next_candidates_t = max_heap_gt, candidates_allocator_t>; - - /** - * @brief A loosely-structured handle for every node. One such node is created for every member. - * To minimize memory usage and maximize the number of entries per cache-line, it only - * stores to pointers. The internal tape starts with a `vector_key_t` @b key, then - * a `level_t` for the number of graph @b levels in which this member appears, - * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences - * for @b each-level. 
- */ - class node_t { - byte_t* tape_{}; - - public: - explicit node_t(byte_t* tape) noexcept : tape_(tape) {} - byte_t* tape() const noexcept { return tape_; } - byte_t* neighbors_tape() const noexcept { return tape_ + node_head_bytes_(); } - explicit operator bool() const noexcept { return tape_; } - - node_t() = default; - node_t(node_t const&) = default; - node_t& operator=(node_t const&) = default; - - misaligned_ref_gt ckey() const noexcept { return {tape_}; } - misaligned_ref_gt key() const noexcept { return {tape_}; } - misaligned_ref_gt level() const noexcept { return {tape_ + sizeof(vector_key_t)}; } - - void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } - void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } - }; - - static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); - static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); - - /** - * @brief A slice of the node's tape, containing a the list of neighbors - * for a node in a single graph level. It's pre-allocated to fit - * as many neighbors "slots", as may be needed at the target level, - * and starts with a single integer `neighbors_count_t` counter. - */ - class neighbors_ref_t { - byte_t* tape_; - - static constexpr std::size_t shift(std::size_t i = 0) { - return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; - } - - public: - neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} - misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } - misaligned_ptr_gt end() noexcept { return begin() + size(); } - misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } - misaligned_ptr_gt end() const noexcept { return begin() + size(); } - compressed_slot_t operator[](std::size_t i) const noexcept { - return misaligned_load(tape_ + shift(i)); - } - std::size_t size() const noexcept { return misaligned_load(tape_); } - void clear() noexcept { - neighbors_count_t n = misaligned_load(tape_); - std::memset(tape_, 0, shift(n)); - // misaligned_store(tape_, 0); - } - void push_back(compressed_slot_t slot) noexcept { - neighbors_count_t n = misaligned_load(tape_); - misaligned_store(tape_ + shift(n), slot); - misaligned_store(tape_, n + 1); - } - }; - - /** - * @brief A package of all kinds of temporary data-structures, that the threads - * would reuse to process requests. Similar to having all of those as - * separate `thread_local` global variables. 
- */ - struct usearch_align_m context_t { - top_candidates_t top_candidates{}; - next_candidates_t next_candidates{}; - visits_hash_set_t visits{}; - std::default_random_engine level_generator{}; - std::size_t iteration_cycles{}; - std::size_t computed_distances_count{}; - - template // - inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { - static_assert( // - std::is_same::value || std::is_same::value, - "Unexpected type"); - - computed_distances_count++; - return metric(first, second); - } - - template // - inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { - static_assert( // - std::is_same::value || std::is_same::value, - "Unexpected type"); - - computed_distances_count++; - return metric(first, second); - } - }; - - index_config_t config_{}; - index_limits_t limits_{}; - - mutable dynamic_allocator_t dynamic_allocator_{}; - tape_allocator_t tape_allocator_{}; - - precomputed_constants_t pre_{}; - memory_mapped_file_t viewed_file_{}; - - /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. - usearch_align_m mutable std::atomic nodes_capacity_{}; - - /// @brief Number of "slots" already storing non-null nodes. - usearch_align_m mutable std::atomic nodes_count_{}; - - /// @brief Controls access to `max_level_` and `entry_slot_`. - /// If any thread is updating those values, no other threads can `add()` or `search()`. - std::mutex global_mutex_{}; - - /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. - level_t max_level_{}; - - /// @brief The slot in which the only node of the top-level graph is stored. - std::size_t entry_slot_{}; - - using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - /// @brief C-style array of `node_t` smart-pointers. - buffer_gt nodes_{}; - - /// @brief Mutex, that limits concurrent access to `nodes_`. - mutable nodes_mutexes_t nodes_mutexes_{}; - - using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - /// @brief Array of thread-specific buffers for temporary data. - mutable buffer_gt contexts_{}; - -public: - std::size_t connectivity() const noexcept { return config_.connectivity; } - std::size_t capacity() const noexcept { return nodes_capacity_; } - std::size_t size() const noexcept { return nodes_count_; } - std::size_t max_level() const noexcept { return nodes_count_ ? static_cast(max_level_) : 0; } - index_config_t const& config() const noexcept { return config_; } - index_limits_t const& limits() const noexcept { return limits_; } - bool is_immutable() const noexcept { return bool(viewed_file_); } - - /** - * @section Exceptions - * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. - */ - explicit index_gt( // - index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, - tape_allocator_t tape_allocator = {}) noexcept - : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), - tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), - entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} - - /** - * @brief Clones the structure with the same hyper-parameters, but without contents. 
- */ - index_gt fork() noexcept { return index_gt{config_, dynamic_allocator_, tape_allocator_}; } - - ~index_gt() noexcept { reset(); } - - index_gt(index_gt&& other) noexcept { swap(other); } - - index_gt& operator=(index_gt&& other) noexcept { - swap(other); - return *this; - } - - struct copy_result_t { - error_t error; - index_gt index; - - explicit operator bool() const noexcept { return !error; } - copy_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - copy_result_t copy(index_copy_config_t config = {}) const noexcept { - copy_result_t result; - index_gt& other = result.index; - other = index_gt(config_, dynamic_allocator_, tape_allocator_); - if (!other.reserve(limits_)) - return result.failed("Failed to reserve the contexts"); - - // Now all is left - is to allocate new `node_t` instances and populate - // the `other.nodes_` array into it. - for (std::size_t i = 0; i != nodes_count_; ++i) - other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); - - other.nodes_count_ = nodes_count_.load(); - other.max_level_ = max_level_; - other.entry_slot_ = entry_slot_; - - // This controls nothing for now :) - (void)config; - return result; - } - - member_citerator_t cbegin() const noexcept { return {this, 0}; } - member_citerator_t cend() const noexcept { return {this, size()}; } - member_citerator_t begin() const noexcept { return {this, 0}; } - member_citerator_t end() const noexcept { return {this, size()}; } - member_iterator_t begin() noexcept { return {this, 0}; } - member_iterator_t end() noexcept { return {this, size()}; } - - member_ref_t at(std::size_t slot) noexcept { return {nodes_[slot].key(), slot}; } - member_cref_t at(std::size_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } - member_iterator_t iterator_at(std::size_t slot) noexcept { return {this, slot}; } - member_citerator_t citerator_at(std::size_t slot) const noexcept { return {this, slot}; } - - dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } - tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } + public: + using distance_t = distance_at; + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using dynamic_allocator_t = dynamic_allocator_at; + using tape_allocator_t = tape_allocator_at; + static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); + + using member_cref_t = member_cref_gt; + using member_ref_t = member_ref_gt; + + template class member_iterator_gt { + using ref_t = ref_at; + using index_t = index_at; + + friend class index_gt; + member_iterator_gt() noexcept {} + member_iterator_gt(index_t* index, std::size_t slot) noexcept : index_(index), slot_(slot) {} + + index_t* index_{}; + std::size_t slot_{}; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = ref_t; + using difference_type = std::ptrdiff_t; + using pointer = void; + using reference = ref_t; + + reference operator*() const noexcept { return {index_->node_at_(slot_).key(), slot_}; } + vector_key_t key() const noexcept { return index_->node_at_(slot_).key(); } + + friend inline std::size_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } + friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } + + member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, slot_ + 
1); } + member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, slot_ - 1); } + member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, slot_ + d); } + member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, slot_ - d); } + + // clang-format off + member_iterator_gt& operator++() noexcept { slot_ += 1; return *this; } + member_iterator_gt& operator--() noexcept { slot_ -= 1; return *this; } + member_iterator_gt& operator+=(difference_type d) noexcept { slot_ += d; return *this; } + member_iterator_gt& operator-=(difference_type d) noexcept { slot_ -= d; return *this; } + bool operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } + bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != other.index_ || slot_ != other.slot_; } + // clang-format on + }; + + using member_iterator_t = member_iterator_gt; + using member_citerator_t = member_iterator_gt; + + // STL compatibility: + using value_type = vector_key_t; + using allocator_type = dynamic_allocator_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = member_ref_t; + using const_reference = member_cref_t; + using pointer = void; + using const_pointer = void; + using iterator = member_iterator_t; + using const_iterator = member_citerator_t; + using reverse_iterator = std::reverse_iterator; + using reverse_const_iterator = std::reverse_iterator; + + using dynamic_allocator_traits_t = std::allocator_traits; + using byte_t = typename dynamic_allocator_t::value_type; + static_assert( // + sizeof(byte_t) == 1, // + "Primary allocator must allocate separate addressable bytes"); + + using tape_allocator_traits_t = std::allocator_traits; + static_assert( // + sizeof(typename tape_allocator_traits_t::value_type) == 1, // + "Tape allocator must allocate separate addressable bytes"); + + private: + /** + * @brief Integer for the number of node neighbors at a specific level of the + * multi-level graph. It's selected to be `std::uint32_t` to improve the + * alignment in most common cases. + */ + using neighbors_count_t = std::uint32_t; + using level_t = std::int16_t; + + /** + * @brief How many bytes of memory are needed to form the "head" of the node. + */ + static constexpr std::size_t node_head_bytes_() { return sizeof(vector_key_t) + sizeof(level_t); } + + using nodes_mutexes_t = bitset_gt; + + using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; + + struct precomputed_constants_t { + double inverse_log_connectivity{}; + std::size_t neighbors_bytes{}; + std::size_t neighbors_base_bytes{}; + }; + /// @brief A space-efficient internal data-structure used in graph traversal queues. + struct candidate_t { + distance_t distance; + compressed_slot_t slot; + inline bool operator<(candidate_t other) const noexcept { return distance < other.distance; } + }; + + using candidates_view_t = span_gt; + using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; + using next_candidates_t = max_heap_gt, candidates_allocator_t>; + + /** + * @brief A loosely-structured handle for every node. One such node is created for every member. + * To minimize memory usage and maximize the number of entries per cache-line, it only + * stores to pointers. 
The internal tape starts with a `vector_key_t` @b key, then + * a `level_t` for the number of graph @b levels in which this member appears, + * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences + * for @b each-level. + */ + class node_t { + byte_t* tape_{}; + + public: + explicit node_t(byte_t* tape) noexcept : tape_(tape) {} + byte_t* tape() const noexcept { return tape_; } + byte_t* neighbors_tape() const noexcept { return tape_ + node_head_bytes_(); } + explicit operator bool() const noexcept { return tape_; } + + node_t() = default; + node_t(node_t const&) = default; + node_t& operator=(node_t const&) = default; + + misaligned_ref_gt ckey() const noexcept { return {tape_}; } + misaligned_ref_gt key() const noexcept { return {tape_}; } + misaligned_ref_gt level() const noexcept { return {tape_ + sizeof(vector_key_t)}; } + + void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } + void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } + }; + + static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); + static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); + + /** + * @brief A slice of the node's tape, containing a the list of neighbors + * for a node in a single graph level. It's pre-allocated to fit + * as many neighbors "slots", as may be needed at the target level, + * and starts with a single integer `neighbors_count_t` counter. + */ + class neighbors_ref_t { + byte_t* tape_; + + static constexpr std::size_t shift(std::size_t i = 0) { + return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; + } + + public: + neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} + misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() noexcept { return begin() + size(); } + misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() const noexcept { return begin() + size(); } + compressed_slot_t operator[](std::size_t i) const noexcept { + return misaligned_load(tape_ + shift(i)); + } + std::size_t size() const noexcept { return misaligned_load(tape_); } + void clear() noexcept { + neighbors_count_t n = misaligned_load(tape_); + std::memset(tape_, 0, shift(n)); + // misaligned_store(tape_, 0); + } + void push_back(compressed_slot_t slot) noexcept { + neighbors_count_t n = misaligned_load(tape_); + misaligned_store(tape_ + shift(n), slot); + misaligned_store(tape_, n + 1); + } + }; + + /** + * @brief A package of all kinds of temporary data-structures, that the threads + * would reuse to process requests. Similar to having all of those as + * separate `thread_local` global variables. 
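Editorial note on the node tape layout described in the `node_t` and `neighbors_ref_t` comments above: a small arithmetic sketch of the byte budget for one node. It assumes 8-byte keys and 4-byte compressed slots; the function name is illustrative, while the implementation precomputes the per-list sizes as `neighbors_base_bytes` and `neighbors_bytes`.

    #include <cstddef>
    #include <cstdint>

    // One head (key + level), one base-level neighbor list sized by `connectivity_base`,
    // and one list per upper level sized by `connectivity`; every list carries a 32-bit count.
    inline std::size_t node_tape_bytes(std::size_t node_level, std::size_t connectivity,
                                       std::size_t connectivity_base) {
        std::size_t const head_bytes = sizeof(std::uint64_t) /*key*/ + sizeof(std::int16_t) /*level*/;
        std::size_t const count_bytes = sizeof(std::uint32_t); // neighbors_count_t prefix
        std::size_t const slot_bytes = sizeof(std::uint32_t);  // compressed_slot_t
        std::size_t const base_list_bytes = count_bytes + connectivity_base * slot_bytes;
        std::size_t const upper_list_bytes = count_bytes + connectivity * slot_bytes;
        return head_bytes + base_list_bytes + node_level * upper_list_bytes;
    }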
+ */ + struct usearch_align_m context_t { + top_candidates_t top_candidates{}; + next_candidates_t next_candidates{}; + visits_hash_set_t visits{}; + std::default_random_engine level_generator{}; + std::size_t iteration_cycles{}; + std::size_t computed_distances_count{}; + + template // + inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + + template // + inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances_count++; + return metric(first, second); + } + }; + + index_config_t config_{}; + index_limits_t limits_{}; + + mutable dynamic_allocator_t dynamic_allocator_{}; + tape_allocator_t tape_allocator_{}; + + precomputed_constants_t pre_{}; + memory_mapped_file_t viewed_file_{}; + + /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. + usearch_align_m mutable std::atomic nodes_capacity_{}; + + /// @brief Number of "slots" already storing non-null nodes. + usearch_align_m mutable std::atomic nodes_count_{}; + + /// @brief Controls access to `max_level_` and `entry_slot_`. + /// If any thread is updating those values, no other threads can `add()` or `search()`. + std::mutex global_mutex_{}; + + /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. + level_t max_level_{}; + + /// @brief The slot in which the only node of the top-level graph is stored. + std::size_t entry_slot_{}; + + using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief C-style array of `node_t` smart-pointers. + buffer_gt nodes_{}; + + /// @brief Mutex, that limits concurrent access to `nodes_`. + mutable nodes_mutexes_t nodes_mutexes_{}; + + using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief Array of thread-specific buffers for temporary data. + mutable buffer_gt contexts_{}; + + public: + std::size_t connectivity() const noexcept { return config_.connectivity; } + std::size_t capacity() const noexcept { return nodes_capacity_; } + std::size_t size() const noexcept { return nodes_count_; } + std::size_t max_level() const noexcept { return nodes_count_ ? static_cast(max_level_) : 0; } + index_config_t const& config() const noexcept { return config_; } + index_limits_t const& limits() const noexcept { return limits_; } + bool is_immutable() const noexcept { return bool(viewed_file_); } + + /** + * @section Exceptions + * Doesn't throw, unless the ::metric's and ::allocators's throw on copy-construction. + */ + explicit index_gt( // + index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept + : config_(config), limits_(0, 0), dynamic_allocator_(std::move(dynamic_allocator)), + tape_allocator_(std::move(tape_allocator)), pre_(precompute_(config)), nodes_count_(0u), max_level_(-1), + entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + + /** + * @brief Clones the structure with the same hyper-parameters, but without contents. 
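Editorial note on the per-context `level_generator` above and the comment that `max_level_` "grows as the logarithm of size": a sketch of the standard HNSW level draw. It presumes `inverse_log_connectivity` in `precomputed_constants_t` is `1 / ln(connectivity)`; `choose_level` is an illustrative stand-in for the private helper.

    #include <cmath>
    #include <cstdint>
    #include <limits>
    #include <random>

    // The draw follows a geometric-like distribution: most nodes land on level 0,
    // each higher level holds roughly a 1/connectivity fraction of the previous one.
    inline std::int16_t choose_level(std::default_random_engine& generator, double inverse_log_connectivity) {
        std::uniform_real_distribution<double> distribution(0.0, 1.0);
        double sample = distribution(generator);
        if (sample <= 0.0)
            sample = std::numeric_limits<double>::min(); // avoid log(0)
        return static_cast<std::int16_t>(-std::log(sample) * inverse_log_connectivity);
    }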
+ */ + index_gt fork() noexcept { return index_gt{config_, dynamic_allocator_, tape_allocator_}; } + + ~index_gt() noexcept { reset(); } + + index_gt(index_gt&& other) noexcept { swap(other); } + + index_gt& operator=(index_gt&& other) noexcept { + swap(other); + return *this; + } + + struct copy_result_t { + error_t error; + index_gt index; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + copy_result_t copy(index_copy_config_t config = {}) const noexcept { + copy_result_t result; + index_gt& other = result.index; + other = index_gt(config_, dynamic_allocator_, tape_allocator_); + if (!other.reserve(limits_)) + return result.failed("Failed to reserve the contexts"); + + // Now all is left - is to allocate new `node_t` instances and populate + // the `other.nodes_` array into it. + for (std::size_t i = 0; i != nodes_count_; ++i) + other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); + + other.nodes_count_ = nodes_count_.load(); + other.max_level_ = max_level_; + other.entry_slot_ = entry_slot_; + + // This controls nothing for now :) + (void)config; + return result; + } + + member_citerator_t cbegin() const noexcept { return {this, 0}; } + member_citerator_t cend() const noexcept { return {this, size()}; } + member_citerator_t begin() const noexcept { return {this, 0}; } + member_citerator_t end() const noexcept { return {this, size()}; } + member_iterator_t begin() noexcept { return {this, 0}; } + member_iterator_t end() noexcept { return {this, size()}; } + + member_ref_t at(std::size_t slot) noexcept { return {nodes_[slot].key(), slot}; } + member_cref_t at(std::size_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } + member_iterator_t iterator_at(std::size_t slot) noexcept { return {this, slot}; } + member_citerator_t citerator_at(std::size_t slot) const noexcept { return {this, slot}; } + + dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } + tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } #pragma region Adjusting Configuration - /** - * @brief Erases all the vectors from the index. - * - * Will change `size()` to zero, but will keep the same `capacity()`. - * Will keep the number of available threads/contexts the same as it was. - */ - void clear() noexcept { - if (!has_reset()) { - std::size_t n = nodes_count_; - for (std::size_t i = 0; i != n; ++i) - node_free_(i); - } else - tape_allocator_.deallocate(nullptr, 0); - nodes_count_ = 0; - max_level_ = -1; - entry_slot_ = 0u; - } - - /** - * @brief Erases all members from index, closing files, and returning RAM to OS. - * - * Will change both `size()` and `capacity()` to zero. - * Will deallocate all threads/contexts. - * If the index is memory-mapped - releases the mapping and the descriptor. - */ - void reset() noexcept { - clear(); - - nodes_ = {}; - contexts_ = {}; - nodes_mutexes_ = {}; - limits_ = index_limits_t{0, 0}; - nodes_capacity_ = 0; - viewed_file_ = memory_mapped_file_t{}; - tape_allocator_ = {}; - } - - /** - * @brief Swaps the underlying memory buffers and thread contexts. 
- */ - void swap(index_gt& other) noexcept { - std::swap(config_, other.config_); - std::swap(limits_, other.limits_); - std::swap(dynamic_allocator_, other.dynamic_allocator_); - std::swap(tape_allocator_, other.tape_allocator_); - std::swap(pre_, other.pre_); - std::swap(viewed_file_, other.viewed_file_); - std::swap(max_level_, other.max_level_); - std::swap(entry_slot_, other.entry_slot_); - std::swap(nodes_, other.nodes_); - std::swap(nodes_mutexes_, other.nodes_mutexes_); - std::swap(contexts_, other.contexts_); - - // Non-atomic parts. - std::size_t capacity_copy = nodes_capacity_; - std::size_t count_copy = nodes_count_; - nodes_capacity_ = other.nodes_capacity_.load(); - nodes_count_ = other.nodes_count_.load(); - other.nodes_capacity_ = capacity_copy; - other.nodes_count_ = count_copy; - } - - /** - * @brief Increases the `capacity()` of the index to allow adding more vectors. - * @return `true` on success, `false` on memory allocation errors. - */ - bool reserve(index_limits_t limits) usearch_noexcept_m { - - if (limits.threads_add <= limits_.threads_add // - && limits.threads_search <= limits_.threads_search // - && limits.members <= limits_.members) - return true; - - nodes_mutexes_t new_mutexes(limits.members); - buffer_gt new_nodes(limits.members); - buffer_gt new_contexts(limits.threads()); - if (!new_nodes || !new_contexts || !new_mutexes) - return false; - - // Move the nodes info, and deallocate previous buffers. - if (nodes_) - std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); - - limits_ = limits; - nodes_capacity_ = limits.members; - nodes_ = std::move(new_nodes); - contexts_ = std::move(new_contexts); - nodes_mutexes_ = std::move(new_mutexes); - return true; - } + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() noexcept { + if (!has_reset()) { + std::size_t n = nodes_count_; + for (std::size_t i = 0; i != n; ++i) + node_free_(i); + } else + tape_allocator_.deallocate(nullptr, 0); + nodes_count_ = 0; + max_level_ = -1; + entry_slot_ = 0u; + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() noexcept { + clear(); + + nodes_ = {}; + contexts_ = {}; + nodes_mutexes_ = {}; + limits_ = index_limits_t{0, 0}; + nodes_capacity_ = 0; + viewed_file_ = memory_mapped_file_t{}; + tape_allocator_ = {}; + } + + /** + * @brief Swaps the underlying memory buffers and thread contexts. + */ + void swap(index_gt& other) noexcept { + std::swap(config_, other.config_); + std::swap(limits_, other.limits_); + std::swap(dynamic_allocator_, other.dynamic_allocator_); + std::swap(tape_allocator_, other.tape_allocator_); + std::swap(pre_, other.pre_); + std::swap(viewed_file_, other.viewed_file_); + std::swap(max_level_, other.max_level_); + std::swap(entry_slot_, other.entry_slot_); + std::swap(nodes_, other.nodes_); + std::swap(nodes_mutexes_, other.nodes_mutexes_); + std::swap(contexts_, other.contexts_); + + // Non-atomic parts. 
+ std::size_t capacity_copy = nodes_capacity_; + std::size_t count_copy = nodes_count_; + nodes_capacity_ = other.nodes_capacity_.load(); + nodes_count_ = other.nodes_count_.load(); + other.nodes_capacity_ = capacity_copy; + other.nodes_count_ = count_copy; + } + + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @return `true` on success, `false` on memory allocation errors. + */ + bool reserve(index_limits_t limits) usearch_noexcept_m { + + if (limits.threads_add <= limits_.threads_add // + && limits.threads_search <= limits_.threads_search // + && limits.members <= limits_.members) + return true; + + nodes_mutexes_t new_mutexes(limits.members); + buffer_gt new_nodes(limits.members); + buffer_gt new_contexts(limits.threads()); + if (!new_nodes || !new_contexts || !new_mutexes) + return false; + + // Move the nodes info, and deallocate previous buffers. + if (nodes_) + std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); + + limits_ = limits; + nodes_capacity_ = limits.members; + nodes_ = std::move(new_nodes); + contexts_ = std::move(new_contexts); + nodes_mutexes_ = std::move(new_mutexes); + return true; + } #pragma endregion #pragma region Construction and Search - struct add_result_t { - error_t error{}; - std::size_t new_size{}; - std::size_t visited_members{}; - std::size_t computed_distances{}; - std::size_t slot{}; - - explicit operator bool() const noexcept { return !error; } - add_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /// @brief Describes a matched search result, augmenting `member_cref_t` - /// contents with `distance` to the query object. - struct match_t { - member_cref_t member; - distance_t distance; - - inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) {} - - inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) {} - - inline match_t(match_t&& other) noexcept - : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} - - inline match_t(match_t const& other) noexcept - : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} - - inline match_t& operator=(match_t const& other) noexcept { - member.key.reset(other.member.key.ptr()); - member.slot = other.member.slot; - distance = other.distance; - return *this; - } - - inline match_t& operator=(match_t&& other) noexcept { - member.key.reset(other.member.key.ptr()); - member.slot = other.member.slot; - distance = other.distance; - return *this; - } - }; - - class search_result_t { - node_t const* nodes_{}; - top_candidates_t const* top_{}; - - friend class index_gt; - inline search_result_t(index_gt const& index, top_candidates_t& top) noexcept - : nodes_(index.nodes_), top_(&top) {} - - public: - /** @brief Number of search results found. */ - std::size_t count{}; - /** @brief Number of graph nodes traversed. */ - std::size_t visited_members{}; - /** @brief Number of times the distances were computed. 
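Editorial note on `reserve()` above: capacity and per-thread contexts must be reserved before inserting, since `add` refuses to grow the index on its own. A sketch of the calling pattern; `expected_member_count` and `threads` are caller-supplied assumptions, and the field names match the `index_limits_t` members used in the hunk.

    #include <cstddef>
    #include "usearch/index.hpp" // path as vendored in this repository

    template <typename index_at>
    bool prepare_for_inserts(index_at& index, std::size_t expected_member_count, std::size_t threads) {
        unum::usearch::index_limits_t limits;
        limits.members = expected_member_count;
        limits.threads_add = threads;    // concurrent add() callers
        limits.threads_search = threads; // concurrent search() callers
        return index.reserve(limits);    // false means the new buffers could not be allocated
    }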
*/ - std::size_t computed_distances{}; - error_t error{}; - - inline search_result_t() noexcept {} - inline search_result_t(search_result_t&&) = default; - inline search_result_t& operator=(search_result_t&&) = default; - - explicit operator bool() const noexcept { return !error; } - search_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - - inline operator std::size_t() const noexcept { return count; } - inline std::size_t size() const noexcept { return count; } - inline bool empty() const noexcept { return !count; } - inline match_t operator[](std::size_t i) const noexcept { return at(i); } - inline match_t front() const noexcept { return at(0); } - inline match_t back() const noexcept { return at(count - 1); } - inline bool contains(vector_key_t key) const noexcept { - for (std::size_t i = 0; i != count; ++i) - if (at(i).member.key == key) - return true; - return false; - } - inline match_t at(std::size_t i) const noexcept { - candidate_t const* top_ordered = top_->data(); - candidate_t candidate = top_ordered[i]; - node_t node = nodes_[candidate.slot]; - return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; - } - inline std::size_t merge_into( // - vector_key_t* keys, distance_t* distances, // - std::size_t old_count, std::size_t max_count) const noexcept { - - std::size_t merged_count = old_count; - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - distance_t* merged_end = distances + merged_count; - std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; - if (offset == max_count) - continue; - - std::size_t count_worse = merged_count - offset - (max_count == merged_count); - std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); - std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); - keys[offset] = result.member.key; - distances[offset] = result.distance; - merged_count += merged_count != max_count; - } - return merged_count; - } - inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - keys[i] = result.member.key; - distances[i] = result.distance; - } - return count; - } - inline std::size_t dump_to(vector_key_t* keys) const noexcept { - for (std::size_t i = 0; i != count; ++i) { - match_t result = operator[](i); - keys[i] = result.member.key; - } - return count; - } - }; - - struct cluster_result_t { - error_t error{}; - std::size_t visited_members{}; - std::size_t computed_distances{}; - match_t cluster{}; - - explicit operator bool() const noexcept { return !error; } - cluster_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** - * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. - * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. - * - * @tparam metric_at - * A function responsible for computing the distance @b (dis-similarity) between two objects. - * It should be callable into distinctly different scenarios: - * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. - * - `distance_t operator() (entry_at, entry_at)` - between existing entries. - * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. 
- * - * @param[in] key External identifier/name/descriptor for the new entry. - * @param[in] value Content that will be compared against other entries to index. - * @param[in] metric Callable object measuring distance between ::value and present objects. - * @param[in] config Configuration options for this specific operation. - * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. - */ - template < // - typename value_at, // - typename metric_at, // - typename callback_at = dummy_callback_t, // - typename prefetch_at = dummy_prefetch_t // - > - add_result_t add( // - vector_key_t key, value_at&& value, metric_at&& metric, // - index_update_config_t config = {}, // - callback_at&& callback = callback_at{}, // - prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { - - add_result_t result; - if (is_immutable()) - return result.failed("Can't add to an immutable index"); - - // Make sure we have enough local memory to perform this request - context_t& context = contexts_[config.thread]; - top_candidates_t& top = context.top_candidates; - next_candidates_t& next = context.next_candidates; - top.clear(); - next.clear(); - - // The top list needs one more slot than the connectivity of the base level - // for the heuristic, that tries to squeeze one more element into saturated list. - std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); - std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); - if (!top.reserve(top_limit)) - return result.failed("Out of memory!"); - if (!next.reserve(config.expansion)) - return result.failed("Out of memory!"); - - // Determining how much memory to allocate for the node depends on the target level - std::unique_lock new_level_lock(global_mutex_); - level_t max_level_copy = max_level_; // Copy under lock - std::size_t entry_idx_copy = entry_slot_; // Copy under lock - level_t target_level = choose_random_level_(context.level_generator); - - // Make sure we are not overflowing - std::size_t capacity = nodes_capacity_.load(); - std::size_t new_slot = nodes_count_.fetch_add(1); - if (new_slot >= capacity) { - nodes_count_.fetch_sub(1); - return result.failed("Reserve capacity ahead of insertions!"); - } - - // Allocate the neighbors - node_t node = node_make_(key, target_level); - if (!node) { - nodes_count_.fetch_sub(1); - return result.failed("Out of memory!"); - } - if (target_level <= max_level_copy) - new_level_lock.unlock(); - - nodes_[new_slot] = node; - result.new_size = new_slot + 1; - result.slot = new_slot; - callback(at(new_slot)); - node_lock_t new_lock = node_lock_(new_slot); - - // Do nothing for the first element - if (!new_slot) { - entry_slot_ = new_slot; - max_level_ = target_level; - return result; - } - - // Pull stats - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - connect_node_across_levels_( // - value, metric, prefetch, // - new_slot, entry_idx_copy, max_level_copy, target_level, // - config, context); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - - // Updating the entry point if needed - if (target_level > max_level_copy) { - entry_slot_ = new_slot; - max_level_ = target_level; - } - return result; - } - - /** - * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. 
- * - * @tparam metric_at - * A function responsible for computing the distance @b (dis-similarity) between two objects. - * It should be callable into distinctly different scenarios: - * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. - * - `distance_t operator() (entry_at, entry_at)` - between existing entries. - * For any possible `entry_at` following interfaces will work: - * - `std::size_t get_slot(entry_at const &)` - * - `vector_key_t get_key(entry_at const &)` - * - * @param[in] iterator Iterator pointing to an existing entry to be replaced. - * @param[in] key External identifier/name/descriptor for the entry. - * @param[in] value Content that will be compared against other entries in the index. - * @param[in] metric Callable object measuring distance between ::value and present objects. - * @param[in] config Configuration options for this specific operation. - * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. - */ - template < // - typename value_at, // - typename metric_at, // - typename callback_at = dummy_callback_t, // - typename prefetch_at = dummy_prefetch_t // - > - add_result_t update( // - member_iterator_t iterator, // - vector_key_t key, // - value_at&& value, // - metric_at&& metric, // - index_update_config_t config = {}, // - callback_at&& callback = callback_at{}, // - prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { - - usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); - add_result_t result; - std::size_t old_slot = iterator.slot_; - - // Make sure we have enough local memory to perform this request - context_t& context = contexts_[config.thread]; - top_candidates_t& top = context.top_candidates; - next_candidates_t& next = context.next_candidates; - top.clear(); - next.clear(); - - // The top list needs one more slot than the connectivity of the base level - // for the heuristic, that tries to squeeze one more element into saturated list. - std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); - std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); - if (!top.reserve(top_limit)) - return result.failed("Out of memory!"); - if (!next.reserve(config.expansion)) - return result.failed("Out of memory!"); - - node_lock_t new_lock = node_lock_(old_slot); - node_t node = node_at_(old_slot); - - level_t node_level = node.level(); - span_bytes_t node_bytes = node_bytes_(node); - std::memset(node_bytes.data(), 0, node_bytes.size()); - node.level(node_level); - - // Pull stats - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - connect_node_across_levels_( // - value, metric, prefetch, // - old_slot, entry_slot_, max_level_, node_level, // - config, context); - node.key(key); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - result.slot = old_slot; - - callback(at(old_slot)); - return result; - } - - /** - * @brief Searches for the closest elements to the given ::query. Thread-safe. - * - * @param[in] query Content that will be compared against other entries in the index. - * @param[in] wanted The upper bound for the number of results to return. - * @param[in] config Configuration options for this specific operation. - * @param[in] predicate Optional filtering predicate for `member_cref_t`. 
- * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. - */ - template < // - typename value_at, // - typename metric_at, // - typename predicate_at = dummy_predicate_t, // - typename prefetch_at = dummy_prefetch_t // - > - search_result_t search( // - value_at&& query, // - std::size_t wanted, // - metric_at&& metric, // - index_search_config_t config = {}, // - predicate_at&& predicate = predicate_at{}, // - prefetch_at&& prefetch = prefetch_at{}) const noexcept { - - context_t& context = contexts_[config.thread]; - top_candidates_t& top = context.top_candidates; - search_result_t result{*this, top}; - if (!nodes_count_) - return result; - - // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - if (config.exact) { - if (!top.reserve(wanted)) - return result.failed("Out of memory!"); - search_exact_(query, metric, predicate, wanted, context); - } else { - next_candidates_t& next = context.next_candidates; - std::size_t expansion = (std::max)(config.expansion, wanted); - if (!next.reserve(expansion)) - return result.failed("Out of memory!"); - if (!top.reserve(expansion)) - return result.failed("Out of memory!"); - - std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); - - // For bottom layer we need a more optimized procedure - if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) - return result.failed("Out of memory!"); - } - - top.sort_ascending(); - top.shrink(wanted); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - result.count = top.size(); - return result; - } - - /** - * @brief Identifies the closest cluster to the given ::query. Thread-safe. - * - * @param[in] query Content that will be compared against other entries in the index. - * @param[in] level The index level to target. Higher means lower resolution. - * @param[in] config Configuration options for this specific operation. - * @param[in] predicate Optional filtering predicate for `member_cref_t`. - * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. 
- */ - template < // - typename value_at, // - typename metric_at, // - typename predicate_at = dummy_predicate_t, // - typename prefetch_at = dummy_prefetch_t // - > - cluster_result_t cluster( // - value_at&& query, // - std::size_t level, // - metric_at&& metric, // - index_cluster_config_t config = {}, // - predicate_at&& predicate = predicate_at{}, // - prefetch_at&& prefetch = prefetch_at{}) const noexcept { - - context_t& context = contexts_[config.thread]; - cluster_result_t result; - if (!nodes_count_) - return result.failed("No clusters to identify"); - - // Go down the level, tracking only the closest match - result.computed_distances = context.computed_distances_count; - result.visited_members = context.iteration_cycles; - - next_candidates_t& next = context.next_candidates; - std::size_t expansion = config.expansion; - if (!next.reserve(expansion)) - return result.failed("Out of memory!"); - - result.cluster.member = at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, - static_cast(level - 1), context)); - result.cluster.distance = context.measure(query, result.cluster.member, metric); - - // Normalize stats - result.computed_distances = context.computed_distances_count - result.computed_distances; - result.visited_members = context.iteration_cycles - result.visited_members; - - (void)predicate; - return result; - } + struct add_result_t { + error_t error{}; + std::size_t new_size{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + std::size_t slot{}; + + explicit operator bool() const noexcept { return !error; } + add_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /// @brief Describes a matched search result, augmenting `member_cref_t` + /// contents with `distance` to the query object. + struct match_t { + member_cref_t member; + distance_t distance; + + inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) {} + + inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) {} + + inline match_t(match_t&& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t(match_t const& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t& operator=(match_t const& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + + inline match_t& operator=(match_t&& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + }; + + class search_result_t { + node_t const* nodes_{}; + top_candidates_t const* top_{}; + + friend class index_gt; + inline search_result_t(index_gt const& index, top_candidates_t& top) noexcept + : nodes_(index.nodes_), top_(&top) {} + + public: + /** @brief Number of search results found. */ + std::size_t count{}; + /** @brief Number of graph nodes traversed. */ + std::size_t visited_members{}; + /** @brief Number of times the distances were computed. 
*/ + std::size_t computed_distances{}; + error_t error{}; + + inline search_result_t() noexcept {} + inline search_result_t(search_result_t&&) = default; + inline search_result_t& operator=(search_result_t&&) = default; + + explicit operator bool() const noexcept { return !error; } + search_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + inline operator std::size_t() const noexcept { return count; } + inline std::size_t size() const noexcept { return count; } + inline bool empty() const noexcept { return !count; } + inline match_t operator[](std::size_t i) const noexcept { return at(i); } + inline match_t front() const noexcept { return at(0); } + inline match_t back() const noexcept { return at(count - 1); } + inline bool contains(vector_key_t key) const noexcept { + for (std::size_t i = 0; i != count; ++i) + if (at(i).member.key == key) + return true; + return false; + } + inline match_t at(std::size_t i) const noexcept { + candidate_t const* top_ordered = top_->data(); + candidate_t candidate = top_ordered[i]; + node_t node = nodes_[candidate.slot]; + return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; + } + inline std::size_t merge_into( // + vector_key_t* keys, distance_t* distances, // + std::size_t old_count, std::size_t max_count) const noexcept { + + std::size_t merged_count = old_count; + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + distance_t* merged_end = distances + merged_count; + std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; + if (offset == max_count) + continue; + + std::size_t count_worse = merged_count - offset - (max_count == merged_count); + std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); + std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); + keys[offset] = result.member.key; + distances[offset] = result.distance; + merged_count += merged_count != max_count; + } + return merged_count; + } + inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + distances[i] = result.distance; + } + return count; + } + inline std::size_t dump_to(vector_key_t* keys) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + } + return count; + } + }; + + struct cluster_result_t { + error_t error{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + match_t cluster{}; + + explicit operator bool() const noexcept { return !error; } + cluster_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. + * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. 
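Editorial note on `search_result_t::merge_into` above: a sketch of folding the top-k of several independently searched shards into one ascending-by-distance result set. It assumes the default key and distance types (`std::uint64_t` / `float`); `shard_results` stands for any container of results returned by separate `search()` calls, and the destination arrays must have room for `max_count` entries.

    #include <cstddef>
    #include <cstdint>

    template <typename results_at>
    std::size_t merge_shards(results_at const& shard_results, std::uint64_t* keys, float* distances,
                             std::size_t max_count) {
        std::size_t merged = 0; // destination starts empty, which is trivially sorted
        for (auto const& result : shard_results)
            merged = result.merge_into(keys, distances, merged, max_count);
        return merged; // number of valid entries now in keys/distances
    }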
+ * + * @param[in] key External identifier/name/descriptor for the new entry. + * @param[in] value Content that will be compared against other entries to index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t add( // + vector_key_t key, value_at&& value, metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + add_result_t result; + if (is_immutable()) + return result.failed("Can't add to an immutable index"); + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + // Determining how much memory to allocate for the node depends on the target level + std::unique_lock new_level_lock(global_mutex_); + level_t max_level_copy = max_level_; // Copy under lock + std::size_t entry_idx_copy = entry_slot_; // Copy under lock + level_t target_level = choose_random_level_(context.level_generator); + + // Make sure we are not overflowing + std::size_t capacity = nodes_capacity_.load(); + std::size_t new_slot = nodes_count_.fetch_add(1); + if (new_slot >= capacity) { + nodes_count_.fetch_sub(1); + return result.failed("Reserve capacity ahead of insertions!"); + } + + // Allocate the neighbors + node_t node = node_make_(key, target_level); + if (!node) { + nodes_count_.fetch_sub(1); + return result.failed("Out of memory!"); + } + if (target_level <= max_level_copy) + new_level_lock.unlock(); + + nodes_[new_slot] = node; + result.new_size = new_slot + 1; + result.slot = new_slot; + callback(at(new_slot)); + node_lock_t new_lock = node_lock_(new_slot); + + // Do nothing for the first element + if (!new_slot) { + entry_slot_ = new_slot; + max_level_ = target_level; + return result; + } + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + new_slot, entry_idx_copy, max_level_copy, target_level, // + config, context); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + // Updating the entry point if needed + if (target_level > max_level_copy) { + entry_slot_ = new_slot; + max_level_ = target_level; + } + return result; + } + + /** + * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. 
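A minimal sketch of the reserve-then-add flow required above (`size() < capacity()`); `index`, `metric`, and `vector_of` are placeholders assumed to exist, and the key type is assumed to be integral.

    index_limits_t limits;
    limits.members = 1000;          // reserve capacity ahead of insertions
    if (!index.reserve(limits))
        return;
    for (std::size_t key = 0; key != 1000; ++key) {
        auto result = index.add(key, vector_of(key), metric);
        if (!result)
            return; // `result.error` carries the message, `result.slot` the assigned slot
    }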
+ * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * For any possible `entry_at` following interfaces will work: + * - `std::size_t get_slot(entry_at const &)` + * - `vector_key_t get_key(entry_at const &)` + * + * @param[in] iterator Iterator pointing to an existing entry to be replaced. + * @param[in] key External identifier/name/descriptor for the entry. + * @param[in] value Content that will be compared against other entries in the index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t update( // + member_iterator_t iterator, // + vector_key_t key, // + value_at&& value, // + metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!config.expansion) + config.expansion = default_expansion_add(); + + usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); + add_result_t result; + std::size_t old_slot = iterator.slot_; + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + node_lock_t new_lock = node_lock_(old_slot); + node_t node = node_at_(old_slot); + + level_t node_level = node.level(); + span_bytes_t node_bytes = node_bytes_(node); + std::memset(node_bytes.data(), 0, node_bytes.size()); + node.level(node_level); + + // Pull stats + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + connect_node_across_levels_( // + value, metric, prefetch, // + old_slot, entry_slot_, max_level_, node_level, // + config, context); + node.key(key); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.slot = old_slot; + + callback(at(old_slot)); + return result; + } + + /** + * @brief Searches for the closest elements to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] wanted The upper bound for the number of results to return. 
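A hedged sketch of an in-place update per the documentation above, assuming `it` is a `member_iterator_t` previously obtained from the index (for example while enumerating members) and that `new_key`, `new_vector`, and `metric` exist.

    auto result = index.update(it, new_key, new_vector, metric);
    if (!result)
        return; // `result.error` explains the failure; `result.slot` remains the old slot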
+ * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. + */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + search_result_t search( // + value_at&& query, // + std::size_t wanted, // + metric_at&& metric, // + index_search_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const usearch_noexcept_m { + + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!wanted) + return search_result_t{}; + + // Expansion factor set to zero is equivalent to the default value + if (!config.expansion) + config.expansion = default_expansion_search(); + + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + search_result_t result{*this, top}; + if (!nodes_count_) + return result; + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + if (config.exact) { + if (!top.reserve(wanted)) + return result.failed("Out of memory!"); + search_exact_(query, metric, predicate, wanted, context); + } else { + next_candidates_t& next = context.next_candidates; + std::size_t expansion = (std::max)(config.expansion, wanted); + usearch_assert_m(expansion > 0, "Expansion factor can't be a zero!"); + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + if (!top.reserve(expansion)) + return result.failed("Out of memory!"); + + std::size_t closest_slot = search_for_one_(query, metric, prefetch, entry_slot_, max_level_, 0, context); + + // For bottom layer we need a more optimized procedure + if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) + return result.failed("Out of memory!"); + } + + top.sort_ascending(); + top.shrink(wanted); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.count = top.size(); + return result; + } + + /** + * @brief Identifies the closest cluster to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] level The index level to target. Higher means lower resolution. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. 
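A short sketch of consuming the `search()` result documented above; `index`, `query`, `metric`, and `consume` are assumptions, and the returned object is only valid until the next `search()`, `add()`, or `cluster()` call on the same thread context.

    std::size_t const wanted = 10;
    auto result = index.search(query, wanted, metric);
    if (!result)
        return; // `result.error` describes the failure, e.g. "Out of memory!"
    for (std::size_t i = 0; i != result.size(); ++i) {
        auto match = result[i];                    // a `match_t`: member reference + distance
        consume(match.member.key, match.distance); // `consume` is a placeholder callback
    }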
+ */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + cluster_result_t cluster( // + value_at&& query, // + std::size_t level, // + metric_at&& metric, // + index_cluster_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + cluster_result_t result; + if (!nodes_count_) + return result.failed("No clusters to identify"); + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances_count; + result.visited_members = context.iteration_cycles; + + next_candidates_t& next = context.next_candidates; + std::size_t expansion = config.expansion; + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + + result.cluster.member = at(search_for_one_(query, metric, prefetch, entry_slot_, max_level_, + static_cast(level - 1), context)); + result.cluster.distance = context.measure(query, result.cluster.member, metric); + + // Normalize stats + result.computed_distances = context.computed_distances_count - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + (void)predicate; + return result; + } #pragma endregion #pragma region Metadata - struct stats_t { - std::size_t nodes{}; - std::size_t edges{}; - std::size_t max_edges{}; - std::size_t allocated_bytes{}; - }; - - stats_t stats() const noexcept { - stats_t result{}; - - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; - std::size_t edges = 0; - for (level_t level = 0; level <= node.level(); ++level) - edges += neighbors_(node, level).size(); - - ++result.nodes; - result.allocated_bytes += node_bytes_(node).size(); - result.edges += edges; - result.max_edges += max_edges; - } - return result; - } - - stats_t stats(std::size_t level) const noexcept { - stats_t result{}; - - std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - if (static_cast(node.level()) < level) - continue; - - ++result.nodes; - result.edges += neighbors_(node, level).size(); - result.allocated_bytes += node_head_bytes_() + neighbors_bytes; - } - - std::size_t max_edges_per_node = level ? 
config_.connectivity_base : config_.connectivity; - result.max_edges = result.nodes * max_edges_per_node; - return result; - } - - stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept { - - std::size_t head_bytes = node_head_bytes_(); - for (std::size_t i = 0; i != size(); ++i) { - node_t node = node_at_(i); - - stats_per_level[0].nodes++; - stats_per_level[0].edges += neighbors_(node, 0).size(); - stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes; - - level_t node_level = static_cast(node.level()); - for (level_t l = 1; l <= (std::min)(node_level, static_cast(max_level)); ++l) { - stats_per_level[l].nodes++; - stats_per_level[l].edges += neighbors_(node, l).size(); - stats_per_level[l].allocated_bytes += pre_.neighbors_bytes; - } - } - - // The `max_edges` parameter can be inferred from `nodes` - stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base; - for (std::size_t l = 1; l <= max_level; ++l) - stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity; - - // Aggregate stats across levels - stats_t result{}; - for (std::size_t l = 0; l <= max_level; ++l) - result.nodes += stats_per_level[l].nodes, // - result.edges += stats_per_level[l].edges, // - result.allocated_bytes += stats_per_level[l].allocated_bytes, // - result.max_edges += stats_per_level[l].max_edges; // - - return result; - } - - /** - * @brief A relatively accurate lower bound on the amount of memory consumed by the system. - * In practice it's error will be below 10%. - * - * @see `serialized_length` for the length of the binary serialized representation. - */ - std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept { - std::size_t total = 0; - if (!viewed_file_) { - stats_t s = stats(); - total += s.allocated_bytes; - total += s.nodes * allocator_entry_bytes; - } - - // Temporary data-structures, proportional to the number of nodes: - total += limits_.members * sizeof(node_t) + allocator_entry_bytes; - - // Temporary data-structures, proportional to the number of threads: - total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3; - return total; - } - - std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); } + struct stats_t { + std::size_t nodes{}; + std::size_t edges{}; + std::size_t max_edges{}; + std::size_t allocated_bytes{}; + }; + + stats_t stats() const noexcept { + stats_t result{}; + + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; + std::size_t edges = 0; + for (level_t level = 0; level <= node.level(); ++level) + edges += neighbors_(node, level).size(); + + ++result.nodes; + result.allocated_bytes += node_bytes_(node).size(); + result.edges += edges; + result.max_edges += max_edges; + } + return result; + } + + stats_t stats(std::size_t level) const noexcept { + stats_t result{}; + + std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + if (static_cast(node.level()) < level) + continue; + + ++result.nodes; + result.edges += neighbors_(node, level).size(); + result.allocated_bytes += node_head_bytes_() + neighbors_bytes; + } + + std::size_t max_edges_per_node = level ? 
config_.connectivity_base : config_.connectivity; + result.max_edges = result.nodes * max_edges_per_node; + return result; + } + + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept { + + std::size_t head_bytes = node_head_bytes_(); + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + + stats_per_level[0].nodes++; + stats_per_level[0].edges += neighbors_(node, 0).size(); + stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes; + + level_t node_level = static_cast(node.level()); + for (level_t l = 1; l <= (std::min)(node_level, static_cast(max_level)); ++l) { + stats_per_level[l].nodes++; + stats_per_level[l].edges += neighbors_(node, l).size(); + stats_per_level[l].allocated_bytes += pre_.neighbors_bytes; + } + } + + // The `max_edges` parameter can be inferred from `nodes` + stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base; + for (std::size_t l = 1; l <= max_level; ++l) + stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity; + + // Aggregate stats across levels + stats_t result{}; + for (std::size_t l = 0; l <= max_level; ++l) + result.nodes += stats_per_level[l].nodes, // + result.edges += stats_per_level[l].edges, // + result.allocated_bytes += stats_per_level[l].allocated_bytes, // + result.max_edges += stats_per_level[l].max_edges; // + + return result; + } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept { + std::size_t total = 0; + if (!viewed_file_) { + stats_t s = stats(); + total += s.allocated_bytes; + total += s.nodes * allocator_entry_bytes; + } + + // Temporary data-structures, proportional to the number of nodes: + total += limits_.members * sizeof(node_t) + allocator_entry_bytes; + + // Temporary data-structures, proportional to the number of threads: + total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3; + return total; + } + + std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); } #pragma endregion #pragma region Serialization - /** - * @brief Estimate the binary length (in bytes) of the serialized index. - */ - std::size_t serialized_length() const noexcept { - std::size_t neighbors_length = 0; - for (std::size_t i = 0; i != size(); ++i) - neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t); - return sizeof(index_serialized_header_t) + neighbors_length; - } - - /** - * @brief Saves serialized binary index representation to a stream. 
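A short sketch of reading the aggregate metadata defined above; `index` is assumed to be a populated instance.

    auto totals = index.stats();  // nodes, edges, max_edges, allocated_bytes across all levels
    auto base = index.stats(0);   // same counters, restricted to the base level
    double saturation = totals.max_edges ? double(totals.edges) / double(totals.max_edges) : 0.0;
    std::size_t approximate_bytes = index.memory_usage(); // lower bound, typically within ~10%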
- */ - template - serialization_result_t save_to_stream(output_callback_at&& output, progress_at&& progress = {}) const noexcept { - - serialization_result_t result; - - // Export some basic metadata - index_serialized_header_t header; - header.size = nodes_count_; - header.connectivity = config_.connectivity; - header.connectivity_base = config_.connectivity_base; - header.max_level = max_level_; - header.entry_slot = entry_slot_; - if (!output(&header, sizeof(header))) - return result.failed("Failed to serialize the header into stream"); - - // Progress status - std::size_t processed = 0; - std::size_t const total = 2 * header.size; - - // Export the number of levels per node - // That is both enough to estimate the overall memory consumption, - // and to be able to estimate the offsets of every entry in the file. - for (std::size_t i = 0; i != header.size; ++i) { - node_t node = node_at_(i); - level_t level = node.level(); - if (!output(&level, sizeof(level))) - return result.failed("Failed to serialize into stream"); - if (!progress(++processed, total)) - return result.failed("Terminated by user"); - } - - // After that dump the nodes themselves - for (std::size_t i = 0; i != header.size; ++i) { - span_bytes_t node_bytes = node_bytes_(node_at_(i)); - if (!output(node_bytes.data(), node_bytes.size())) - return result.failed("Failed to serialize into stream"); - if (!progress(++processed, total)) - return result.failed("Terminated by user"); - } - - return {}; - } - - /** - * @brief Symmetric to `save_from_stream`, pulls data from a stream. - */ - template - serialization_result_t load_from_stream(input_callback_at&& input, progress_at&& progress = {}) noexcept { - - serialization_result_t result; - - // Remove previously stored objects - reset(); - - // Pull basic metadata - index_serialized_header_t header; - if (!input(&header, sizeof(header))) - return result.failed("Failed to pull the header from the stream"); - - // We are loading an empty index, no more work to do - if (!header.size) { - reset(); - return result; - } - - // Allocate some dynamic memory to read all the levels - using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt levels(header.size); - if (!levels) - return result.failed("Out of memory"); - if (!input(levels, header.size * sizeof(level_t))) - return result.failed("Failed to pull nodes levels from the stream"); - - // Submit metadata - config_.connectivity = header.connectivity; - config_.connectivity_base = header.connectivity_base; - pre_ = precompute_(config_); - index_limits_t limits; - limits.members = header.size; - if (!reserve(limits)) { - reset(); - return result.failed("Out of memory"); - } - nodes_count_ = header.size; - max_level_ = static_cast(header.max_level); - entry_slot_ = static_cast(header.entry_slot); - - // Load the nodes - for (std::size_t i = 0; i != header.size; ++i) { - span_bytes_t node_bytes = node_malloc_(levels[i]); - if (!input(node_bytes.data(), node_bytes.size())) { - reset(); - return result.failed("Failed to pull nodes from the stream"); - } - nodes_[i] = node_t{node_bytes.data()}; - if (!progress(i + 1, header.size)) - return result.failed("Terminated by user"); - } - return {}; - } - - template - serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept { - return save(output_file_t(file_path), std::forward(progress)); - } - - template - serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept { - return 
load(input_file_t(file_path), std::forward(progress)); - } - - /** - * @brief Saves serialized binary index representation to a file, generally on disk. - */ - template - serialization_result_t save(output_file_t file, progress_at&& progress = {}) const noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void* buffer, std::size_t length) { - io_result = file.write(buffer, length); - return !!io_result; - }, - std::forward(progress)); - - if (!stream_result) - return stream_result; - return io_result; - } - - /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. - */ - template - serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, - progress_at&& progress = {}) const noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void* buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(file.data() + offset, buffer, length); - offset += length; - return true; - }, - std::forward(progress)); - - return stream_result; - } - - /** - * @brief Loads the serialized binary index representation from disk to RAM. - * Adjusts the configuration properties of the constructed index to - * match the settings in the file. - */ - template - serialization_result_t load(input_file_t file, progress_at&& progress = {}) noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void* buffer, std::size_t length) { - io_result = file.read(buffer, length); - return !!io_result; - }, - std::forward(progress)); - - if (!stream_result) - return stream_result; - return io_result; - } - - /** - * @brief Loads the serialized binary index representation from disk to RAM. - * Adjusts the configuration properties of the constructed index to - * match the settings in the file. - */ - template - serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, - progress_at&& progress = {}) noexcept { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void* buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(buffer, file.data() + offset, length); - offset += length; - return true; - }, - std::forward(progress)); - - return stream_result; - } - - /** - * @brief Memory-maps the serialized binary index representation from disk, - * @b without copying data into RAM, and fetching it on-demand. 
- */ - template - serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, - progress_at&& progress = {}) noexcept { - - // Remove previously stored objects - reset(); - - serialization_result_t result = file.open_if_not(); - if (!result) - return result; - - // Pull basic metadata - index_serialized_header_t header; - if (file.size() - offset < sizeof(header)) - return result.failed("File is corrupted and lacks a header"); - std::memcpy(&header, file.data() + offset, sizeof(header)); - - if (!header.size) { - reset(); - return result; - } - - // Precompute offsets of every node, but before that we need to update the configs - // This could have been done with `std::exclusive_scan`, but it's only available from C++17. - using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt offsets(header.size); - if (!offsets) - return result.failed("Out of memory"); - - config_.connectivity = header.connectivity; - config_.connectivity_base = header.connectivity_base; - pre_ = precompute_(config_); - misaligned_ptr_gt levels{(byte_t*)file.data() + offset + sizeof(header)}; - offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; - for (std::size_t i = 1; i < header.size; ++i) - offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]); - - std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]); - if (file.size() < total_bytes) { - reset(); - return result.failed("File is corrupted and can't fit all the nodes"); - } - - // Submit metadata and reserve memory - index_limits_t limits; - limits.members = header.size; - if (!reserve(limits)) { - reset(); - return result.failed("Out of memory"); - } - nodes_count_ = header.size; - max_level_ = static_cast(header.max_level); - entry_slot_ = static_cast(header.entry_slot); - - // Rapidly address all the nodes - for (std::size_t i = 0; i != header.size; ++i) { - nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]}; - if (!progress(i + 1, header.size)) - return result.failed("Terminated by user"); - } - viewed_file_ = std::move(file); - return {}; - } + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length() const noexcept { + std::size_t neighbors_length = 0; + for (std::size_t i = 0; i != size(); ++i) + neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t); + return sizeof(index_serialized_header_t) + neighbors_length; + } + + /** + * @brief Saves serialized binary index representation to a stream. + */ + template + serialization_result_t save_to_stream(output_callback_at&& output, progress_at&& progress = {}) const noexcept { + + serialization_result_t result; + + // Export some basic metadata + index_serialized_header_t header; + header.size = nodes_count_; + header.connectivity = config_.connectivity; + header.connectivity_base = config_.connectivity_base; + header.max_level = max_level_; + header.entry_slot = entry_slot_; + if (!output(&header, sizeof(header))) + return result.failed("Failed to serialize the header into stream"); + + // Progress status + std::size_t processed = 0; + std::size_t const total = 2 * header.size; + + // Export the number of levels per node + // That is both enough to estimate the overall memory consumption, + // and to be able to estimate the offsets of every entry in the file. 
+ for (std::size_t i = 0; i != header.size; ++i) { + node_t node = node_at_(i); + level_t level = node.level(); + if (!output(&level, sizeof(level))) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + // After that dump the nodes themselves + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_bytes_(node_at_(i)); + if (!output(node_bytes.data(), node_bytes.size())) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + return {}; + } + + /** + * @brief Symmetric to `save_from_stream`, pulls data from a stream. + */ + template + serialization_result_t load_from_stream(input_callback_at&& input, progress_at&& progress = {}) noexcept { + + serialization_result_t result; + + // Remove previously stored objects + reset(); + + // Pull basic metadata + index_serialized_header_t header; + if (!input(&header, sizeof(header))) + return result.failed("Failed to pull the header from the stream"); + + // We are loading an empty index, no more work to do + if (!header.size) { + reset(); + return result; + } + + // Allocate some dynamic memory to read all the levels + using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt levels(header.size); + if (!levels) + return result.failed("Out of memory"); + if (!input(levels, header.size * sizeof(level_t))) + return result.failed("Failed to pull nodes levels from the stream"); + + // Submit metadata + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + pre_ = precompute_(config_); + index_limits_t limits; + limits.members = header.size; + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Load the nodes + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_malloc_(levels[i]); + if (!input(node_bytes.data(), node_bytes.size())) { + reset(); + return result.failed("Failed to pull nodes from the stream"); + } + nodes_[i] = node_t{node_bytes.data()}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + return {}; + } + + template + serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept { + return save(output_file_t(file_path), std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept { + return load(input_file_t(file_path), std::forward(progress)); + } + + /** + * @brief Saves serialized binary index representation to a file, generally on disk. + */ + template + serialization_result_t save(output_file_t file, progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) + return stream_result; + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. 
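A sketch of the simplest persistence round-trip using the path-based overloads above; the file name is an arbitrary example.

    auto saved = index.save("index.usearch");
    if (!saved)
        return; // `serialization_result_t` converts to bool and carries `error`
    auto loaded = index.load("index.usearch");
    if (!loaded)
        return;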
+ */ + template + serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(input_file_t file, progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) + return stream_result; + return io_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + // Remove previously stored objects + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Pull basic metadata + index_serialized_header_t header; + if (file.size() - offset < sizeof(header)) + return result.failed("File is corrupted and lacks a header"); + std::memcpy(&header, file.data() + offset, sizeof(header)); + + if (!header.size) { + reset(); + return result; + } + + // Precompute offsets of every node, but before that we need to update the configs + // This could have been done with `std::exclusive_scan`, but it's only available from C++17. 
+ using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt offsets(header.size); + if (!offsets) + return result.failed("Out of memory"); + + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + pre_ = precompute_(config_); + misaligned_ptr_gt levels{(byte_t*)file.data() + offset + sizeof(header)}; + offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; + for (std::size_t i = 1; i < header.size; ++i) + offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]); + + std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]); + if (file.size() < total_bytes) { + reset(); + return result.failed("File is corrupted and can't fit all the nodes"); + } + + // Submit metadata and reserve memory + index_limits_t limits; + limits.members = header.size; + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Rapidly address all the nodes + for (std::size_t i = 0; i != header.size; ++i) { + nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + viewed_file_ = std::move(file); + return {}; + } #pragma endregion - /** - * @brief Performs compaction on the whole HNSW index, purging some entries - * and links to them, while also generating a more efficient mapping, - * putting the more frequently used entries closer together. - * - * - * Scans the whole collection, removing the links leading towards - * banned entries. This essentially isolates some nodes from the rest - * of the graph, while keeping their outgoing links, in case the node - * is structurally relevant and has a crucial role in the index. - * It won't reclaim the memory. - * - * @param[in] allow_member Predicate to mark nodes for isolation. - * @param[in] executor Thread-pool to execute the job in parallel. - * @param[in] progress Callback to report the execution progress. - */ - template - void compact( // - values_at&& values, // - metric_at&& metric, // - slot_transition_at&& slot_transition, // - - executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}, // - prefetch_at&& prefetch = prefetch_at{}) noexcept { - - // Export all the keys, slots, and levels. - // Partition them with the predicate. - // Sort the allowed entries in descending order of their level. - // Create a new array mapping old slots to the new ones (INT_MAX for deleted items). 
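A hedged sketch of memory-mapped viewing as described above: nodes stay on disk and are paged in on demand. Constructing `memory_mapped_file_t` from a path is assumed to mirror the other file-path overloads.

    auto viewed = index.view(memory_mapped_file_t("index.usearch"));
    if (!viewed)
        return;
    // A viewed index is immutable, so a subsequent `add()` is expected to fail
    // with "Can't add to an immutable index".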
- struct slot_level_t { - compressed_slot_t old_slot; - compressed_slot_t cluster; - level_t level; - }; - using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt slots_and_levels(size()); - - // Progress status - std::atomic do_tasks{true}; - std::atomic processed{0}; - std::size_t const total = 3 * slots_and_levels.size(); - - // For every bottom level node, determine its parent cluster - executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) { - context_t& context = contexts_[thread_idx]; - std::size_t cluster = search_for_one_( // - values[citerator_at(old_slot)], // - metric, prefetch, // - entry_slot_, max_level_, 0, context); - slots_and_levels[old_slot] = { // - static_cast(old_slot), // - static_cast(cluster), // - node_at_(old_slot).level()}; - ++processed; - if (thread_idx == 0) - do_tasks = progress(processed.load(), total); - return do_tasks.load(); - }); - if (!do_tasks.load()) - return; - - // Where the actual permutation happens: - std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) { - return a.level == b.level ? a.cluster < b.cluster : a.level > b.level; - }); - - using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt old_slot_to_new(slots_and_levels.size()); - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) - old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; - - // Erase all the incoming links - buffer_gt reordered_nodes(slots_and_levels.size()); - tape_allocator_t reordered_tape; - - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { - std::size_t old_slot = slots_and_levels[new_slot].old_slot; - node_t old_node = node_at_(old_slot); - - std::size_t node_bytes = node_bytes_(old_node.level()); - byte_t* new_data = (byte_t*)reordered_tape.allocate(node_bytes); - node_t new_node{new_data}; - std::memcpy(new_data, old_node.tape(), node_bytes); - - for (level_t level = 0; level <= old_node.level(); ++level) - for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) - neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); - - reordered_nodes[new_slot] = new_node; - if (!progress(++processed, total)) - return; - } - - for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { - std::size_t old_slot = slots_and_levels[new_slot].old_slot; - slot_transition(node_at_(old_slot).ckey(), // - static_cast(old_slot), // - static_cast(new_slot)); - if (!progress(++processed, total)) - return; - } - - nodes_ = std::move(reordered_nodes); - tape_allocator_ = std::move(reordered_tape); - entry_slot_ = old_slot_to_new[entry_slot_]; - } - - /** - * @brief Scans the whole collection, removing the links leading towards - * banned entries. This essentially isolates some nodes from the rest - * of the graph, while keeping their outgoing links, in case the node - * is structurally relevant and has a crucial role in the index. - * It won't reclaim the memory. - * - * @param[in] allow_member Predicate to mark nodes for isolation. - * @param[in] executor Thread-pool to execute the job in parallel. - * @param[in] progress Callback to report the execution progress. 
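A hedged sketch of invoking the compaction routine above; `values` must be addressable by the member iterators handed to the metric, and `remap` is an external old-slot to new-slot table introduced only for illustration.

    index.compact(values, metric, [&](auto key, auto old_slot, auto new_slot) {
        (void)key;
        remap[old_slot] = new_slot; // record the permutation for any external storage
    });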
- */ - template < // - typename allow_member_at = dummy_predicate_t, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > - void isolate( // - allow_member_at&& allow_member, // - executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}) noexcept { - - // Progress status - std::atomic do_tasks{true}; - std::atomic processed{0}; - - // Erase all the incoming links - std::size_t nodes_count = size(); - executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { - node_t node = node_at_(node_idx); - for (level_t level = 0; level <= node.level(); ++level) { - neighbors_ref_t neighbors = neighbors_(node, level); - std::size_t old_size = neighbors.size(); - neighbors.clear(); - for (std::size_t i = 0; i != old_size; ++i) { - compressed_slot_t neighbor_slot = neighbors[i]; - node_t neighbor = node_at_(neighbor_slot); - if (allow_member(member_cref_t{neighbor.ckey(), neighbor_slot})) - neighbors.push_back(neighbor_slot); - } - } - ++processed; - if (thread_idx == 0) - do_tasks = progress(processed.load(), nodes_count); - return do_tasks.load(); - }); - - // At the end report the latest numbers, because the reporter thread may be finished earlier - progress(processed.load(), nodes_count); - } - -private: - inline static precomputed_constants_t precompute_(index_config_t const& config) noexcept { - precomputed_constants_t pre; - pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); - pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); - pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); - return pre; - } - - using span_bytes_t = span_gt; - - inline span_bytes_t node_bytes_(node_t node) const noexcept { return {node.tape(), node_bytes_(node.level())}; } - inline std::size_t node_bytes_(level_t level) const noexcept { - return node_head_bytes_() + node_neighbors_bytes_(level); - } - inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); } - inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { - return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; - } - - span_bytes_t node_malloc_(level_t level) noexcept { - std::size_t node_bytes = node_bytes_(level); - byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes); - return data ? 
span_bytes_t{data, node_bytes} : span_bytes_t{}; - } - - node_t node_make_(vector_key_t key, level_t level) noexcept { - span_bytes_t node_bytes = node_malloc_(level); - if (!node_bytes) - return {}; - - std::memset(node_bytes.data(), 0, node_bytes.size()); - node_t node{(byte_t*)node_bytes.data()}; - node.key(key); - node.level(level); - return node; - } - - node_t node_make_copy_(span_bytes_t old_bytes) noexcept { - byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.size()); - if (!data) - return {}; - std::memcpy(data, old_bytes.data(), old_bytes.size()); - return node_t{data}; - } - - void node_free_(std::size_t idx) noexcept { - if (viewed_file_) - return; - - node_t& node = nodes_[idx]; - tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); - node = node_t{}; - } - - inline node_t node_at_(std::size_t idx) const noexcept { return nodes_[idx]; } - inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } - - inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { - return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; - } - - inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { - return level ? neighbors_non_base_(node, level) : neighbors_base_(node); - } - - struct node_lock_t { - nodes_mutexes_t& mutexes; - std::size_t slot; - inline ~node_lock_t() noexcept { mutexes.atomic_reset(slot); } - }; - - inline node_lock_t node_lock_(std::size_t slot) const noexcept { - while (nodes_mutexes_.atomic_set(slot)) - ; - return {nodes_mutexes_, slot}; - } - - template - void connect_node_across_levels_( // - value_at&& value, metric_at&& metric, prefetch_at&& prefetch, // - std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, // - index_update_config_t const& config, context_t& context) usearch_noexcept_m { - - // Go down the level, tracking only the closest match - std::size_t closest_slot = search_for_one_( // - value, metric, prefetch, // - entry_slot, max_level, target_level, context); - - // From `target_level` down perform proper extensive search - for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) { - // TODO: Handle out of memory conditions - search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context); - closest_slot = connect_new_node_(metric, node_slot, level, context); - reconnect_neighbor_nodes_(metric, node_slot, value, level, context); - } - } - - template - std::size_t connect_new_node_( // - metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m { - - node_t new_node = node_at_(new_slot); - top_candidates_t& top = context.top_candidates; - - // Outgoing links from `new_slot`: - neighbors_ref_t new_neighbors = neighbors_(new_node, level); - { - usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); - candidates_view_t top_view = refine_(metric, config_.connectivity, top, context); - - for (std::size_t idx = 0; idx != top_view.size(); idx++) { - usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); - usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); - new_neighbors.push_back(top_view[idx].slot); - } - } - - return new_neighbors[0]; - } - - template - void reconnect_neighbor_nodes_( // - metric_at&& metric, std::size_t new_slot, value_at&& value, level_t level, - context_t& 
context) usearch_noexcept_m { - - node_t new_node = node_at_(new_slot); - top_candidates_t& top = context.top_candidates; - neighbors_ref_t new_neighbors = neighbors_(new_node, level); - - // Reverse links from the neighbors: - std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; - for (compressed_slot_t close_slot : new_neighbors) { - if (close_slot == new_slot) - continue; - node_lock_t close_lock = node_lock_(close_slot); - node_t close_node = node_at_(close_slot); - - neighbors_ref_t close_header = neighbors_(close_node, level); - usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption"); - usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); - usearch_assert_m(level <= close_node.level(), "Linking to missing level"); - - // If `new_slot` is already present in the neighboring connections of `close_slot` - // then no need to modify any connections or run the heuristics. - if (close_header.size() < connectivity_max) { - close_header.push_back(static_cast(new_slot)); - continue; - } - - // To fit a new connection we need to drop an existing one. - top.clear(); - usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`"); - top.insert_reserved( - {context.measure(value, citerator_at(close_slot), metric), static_cast(new_slot)}); - for (compressed_slot_t successor_slot : close_header) - top.insert_reserved( - {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); - - // Export the results: - close_header.clear(); - candidates_view_t top_view = refine_(metric, connectivity_max, top, context); - for (std::size_t idx = 0; idx != top_view.size(); idx++) - close_header.push_back(top_view[idx].slot); - } - } - - level_t choose_random_level_(std::default_random_engine& level_generator) const noexcept { - std::uniform_real_distribution distribution(0.0, 1.0); - double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity; - return (level_t)r; - } - - struct candidates_range_t; - class candidates_iterator_t { - friend struct candidates_range_t; - - index_gt const& index_; - neighbors_ref_t neighbors_; - visits_hash_set_t& visits_; - std::size_t current_; - - candidates_iterator_t& skip_missing() noexcept { - if (!visits_.size()) - return *this; - while (current_ != neighbors_.size()) { - compressed_slot_t neighbor_slot = neighbors_[current_]; - if (visits_.test(neighbor_slot)) - current_++; - else - break; - } - return *this; - } - - public: - using element_t = compressed_slot_t; - using iterator_category = std::forward_iterator_tag; - using value_type = element_t; - using difference_type = std::ptrdiff_t; - using pointer = misaligned_ptr_gt; - using reference = misaligned_ref_gt; - - reference operator*() const noexcept { return slot(); } - candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits, - std::size_t progress) noexcept - : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {} - candidates_iterator_t operator++(int) noexcept { - return candidates_iterator_t(index_, visits_, neighbors_, current_ + 1).skip_missing(); - } - candidates_iterator_t& operator++() noexcept { - ++current_; - skip_missing(); - return *this; - } - bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; } - bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; } - - 
vector_key_t key() const noexcept { return index_->node_at_(slot()).key(); } - compressed_slot_t slot() const noexcept { return neighbors_[current_]; } - friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); } - friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); } - }; - - struct candidates_range_t { - index_gt const& index; - neighbors_ref_t neighbors; - visits_hash_set_t& visits; - - candidates_iterator_t begin() const noexcept { - return candidates_iterator_t{index, neighbors, visits, 0}.skip_missing(); - } - candidates_iterator_t end() const noexcept { return {index, neighbors, visits, neighbors.size()}; } - }; - - template - std::size_t search_for_one_( // - value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // - std::size_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { - - visits_hash_set_t& visits = context.visits; - visits.clear(); - - // Optional prefetching - if (!is_dummy()) - prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1)); - - distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); - for (level_t level = begin_level; level > end_level; --level) { - bool changed; - do { - changed = false; - node_lock_t closest_lock = node_lock_(closest_slot); - neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); - - // Optional prefetching - if (!is_dummy()) { - candidates_range_t missing_candidates{*this, closest_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Actual traversal - for (compressed_slot_t candidate_slot : closest_neighbors) { - distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); - if (candidate_dist < closest_dist) { - closest_dist = candidate_dist; - closest_slot = candidate_slot; - changed = true; - } - } - context.iteration_cycles++; - } while (changed); - } - return closest_slot; - } - - /** - * @brief Traverses a layer of a graph, to find the best place to insert a new node. - * Locks the nodes in the process, assuming other threads are updating neighbors lists. - * @return `true` if procedure succeeded, `false` if run out of memory. 
- */ - template - bool search_to_insert_( // - value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // - std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit, - context_t& context) noexcept { - - visits_hash_set_t& visits = context.visits; - next_candidates_t& next = context.next_candidates; // pop min, push - top_candidates_t& top = context.top_candidates; // pop max, push - - visits.clear(); - next.clear(); - top.clear(); - if (!visits.reserve(config_.connectivity_base + 1u)) - return false; - - // Optional prefetching - if (!is_dummy()) - prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); - - distance_t radius = context.measure(query, citerator_at(start_slot), metric); - next.insert_reserved({-radius, static_cast(start_slot)}); - top.insert_reserved({radius, static_cast(start_slot)}); - visits.set(static_cast(start_slot)); - - while (!next.empty()) { - - candidate_t candidacy = next.top(); - if ((-candidacy.distance) > radius && top.size() == top_limit) - break; - - next.pop(); - context.iteration_cycles++; - - compressed_slot_t candidate_slot = candidacy.slot; - if (new_slot == candidate_slot) - continue; - node_t candidate_ref = node_at_(candidate_slot); - node_lock_t candidate_lock = node_lock_(candidate_slot); - neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); - - // Optional prefetching - if (!is_dummy()) { - candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Assume the worst-case when reserving memory - if (!visits.reserve(visits.size() + candidate_neighbors.size())) - return false; - - for (compressed_slot_t successor_slot : candidate_neighbors) { - if (visits.set(successor_slot)) - continue; - - // node_lock_t successor_lock = node_lock_(successor_slot); - distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); - if (top.size() < top_limit || successor_dist < radius) { - // This can substantially grow our priority queue: - next.insert({-successor_dist, successor_slot}); - // This will automatically evict poor matches: - top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; - } - } - } - return true; - } - - /** - * @brief Traverses the @b base layer of a graph, to find a close match. - * Doesn't lock any nodes, assuming read-only simultaneous access. - * @return `true` if procedure succeeded, `false` if run out of memory. 
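The `expansion` knob threaded through the routines above comes from the per-call configuration objects. A short sketch of setting them, assuming `index`, `metric`, `key`, `vector`, `query`, and `worker_id` exist.

    index_update_config_t add_config;
    add_config.expansion = 128;    // beam width while inserting
    add_config.thread = worker_id; // selects the per-thread context; bounded by the reserved thread count
    index.add(key, vector, metric, add_config);

    index_search_config_t search_config;
    search_config.expansion = 64;  // beam width while searching; clamped to at least `wanted`
    search_config.exact = false;   // set true for a brute-force scan
    auto result = index.search(query, 10, metric, search_config);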
- */ - template - bool search_to_find_in_base_( // - value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, // - std::size_t start_slot, std::size_t expansion, context_t& context) const noexcept { - - visits_hash_set_t& visits = context.visits; - next_candidates_t& next = context.next_candidates; // pop min, push - top_candidates_t& top = context.top_candidates; // pop max, push - std::size_t const top_limit = expansion; - - visits.clear(); - next.clear(); - top.clear(); - if (!visits.reserve(config_.connectivity_base + 1u)) - return false; - - // Optional prefetching - if (!is_dummy()) - prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); - - distance_t radius = context.measure(query, citerator_at(start_slot), metric); - next.insert_reserved({-radius, static_cast(start_slot)}); - top.insert_reserved({radius, static_cast(start_slot)}); - visits.set(static_cast(start_slot)); - - while (!next.empty()) { - - candidate_t candidate = next.top(); - if ((-candidate.distance) > radius) - break; - - next.pop(); - context.iteration_cycles++; - - neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot)); - - // Optional prefetching - if (!is_dummy()) { - candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; - prefetch(missing_candidates.begin(), missing_candidates.end()); - } - - // Assume the worst-case when reserving memory - if (!visits.reserve(visits.size() + candidate_neighbors.size())) - return false; - - for (compressed_slot_t successor_slot : candidate_neighbors) { - if (visits.set(successor_slot)) - continue; - - distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); - if (top.size() < top_limit || successor_dist < radius) { - // This can substantially grow our priority queue: - next.insert({-successor_dist, successor_slot}); - if (!is_dummy()) - if (!predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) - continue; - - // This will automatically evict poor matches: - top.insert({successor_dist, successor_slot}, top_limit); - radius = top.top().distance; - } - } - } - - return true; - } - - /** - * @brief Iterates through all members, without actually touching the index. - */ - template - void search_exact_( // - value_at&& query, metric_at&& metric, predicate_at&& predicate, // - std::size_t count, context_t& context) const noexcept { - - top_candidates_t& top = context.top_candidates; - top.clear(); - top.reserve(count); - for (std::size_t i = 0; i != size(); ++i) { - if (!is_dummy()) - if (!predicate(at(i))) - continue; - - distance_t distance = context.measure(query, citerator_at(i), metric); - top.insert(candidate_t{distance, static_cast(i)}, count); - } - } - - /** - * @brief This algorithm from the original paper implements a heuristic, - * that massively reduces the number of connections a point has, - * to keep only the neighbors, that are from each other. - */ - template - candidates_view_t refine_( // - metric_at&& metric, // - std::size_t needed, top_candidates_t& top, context_t& context) const noexcept { - - top.sort_ascending(); - candidate_t* top_data = top.data(); - std::size_t const top_count = top.size(); - if (top_count < needed) - return {top_data, top_count}; - - std::size_t submitted_count = 1; - std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`. 
- while (submitted_count < needed && consumed_count < top_count) { - candidate_t candidate = top_data[consumed_count]; - bool good = true; - for (std::size_t idx = 0; idx < submitted_count; idx++) { - candidate_t submitted = top_data[idx]; - distance_t inter_result_dist = context.measure( // - citerator_at(candidate.slot), // - citerator_at(submitted.slot), // - metric); - if (inter_result_dist < candidate.distance) { - good = false; - break; - } - } - - if (good) { - top_data[submitted_count] = top_data[consumed_count]; - submitted_count++; - } - consumed_count++; - } - - top.shrink(submitted_count); - return {top_data, submitted_count}; - } + /** + * @brief Performs compaction on the whole HNSW index, purging some entries + * and links to them, while also generating a more efficient mapping, + * putting the more frequently used entries closer together. + * + * + * Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ + template + void compact( // + values_at&& values, // + metric_at&& metric, // + slot_transition_at&& slot_transition, // + + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}, // + prefetch_at&& prefetch = prefetch_at{}) noexcept { + + // Export all the keys, slots, and levels. + // Partition them with the predicate. + // Sort the allowed entries in descending order of their level. + // Create a new array mapping old slots to the new ones (INT_MAX for deleted items). + struct slot_level_t { + compressed_slot_t old_slot; + compressed_slot_t cluster; + level_t level; + }; + using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt slots_and_levels(size()); + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + std::size_t const total = 3 * slots_and_levels.size(); + + // For every bottom level node, determine its parent cluster + executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot) { + context_t& context = contexts_[thread_idx]; + std::size_t cluster = search_for_one_( // + values[citerator_at(old_slot)], // + metric, prefetch, // + entry_slot_, max_level_, 0, context); + slots_and_levels[old_slot] = { // + static_cast(old_slot), // + static_cast(cluster), // + node_at_(old_slot).level()}; + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), total); + return do_tasks.load(); + }); + if (!do_tasks.load()) + return; + + // Where the actual permutation happens: + std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) { + return a.level == b.level ? 
a.cluster < b.cluster : a.level > b.level; + }); + + using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt old_slot_to_new(slots_and_levels.size()); + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) + old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; + + // Erase all the incoming links + buffer_gt reordered_nodes(slots_and_levels.size()); + tape_allocator_t reordered_tape; + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + node_t old_node = node_at_(old_slot); + + std::size_t node_bytes = node_bytes_(old_node.level()); + byte_t* new_data = (byte_t*)reordered_tape.allocate(node_bytes); + node_t new_node{new_data}; + std::memcpy(new_data, old_node.tape(), node_bytes); + + for (level_t level = 0; level <= old_node.level(); ++level) + for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) + neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); + + reordered_nodes[new_slot] = new_node; + if (!progress(++processed, total)) + return; + } + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + slot_transition(node_at_(old_slot).ckey(), // + static_cast(old_slot), // + static_cast(new_slot)); + if (!progress(++processed, total)) + return; + } + + nodes_ = std::move(reordered_nodes); + tape_allocator_ = std::move(reordered_tape); + entry_slot_ = old_slot_to_new[entry_slot_]; + } + + /** + * @brief Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. 
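+ *
+ * A hedged usage sketch (the predicate body and the `deleted_keys` container are
+ * illustrative, not part of the API); with the default dummy executor and progress
+ * callback the scan runs single-threaded:
+ *
+ *      index.isolate([&](member_cref_t member) noexcept {
+ *          return deleted_keys.count(member.key) == 0; // keep links to members still allowed
+ *      });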
+ */ + template < // + typename allow_member_at = dummy_predicate_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + void isolate( // + allow_member_at&& allow_member, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + + // Erase all the incoming links + std::size_t nodes_count = size(); + executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { + node_t node = node_at_(node_idx); + for (level_t level = 0; level <= node.level(); ++level) { + neighbors_ref_t neighbors = neighbors_(node, level); + std::size_t old_size = neighbors.size(); + neighbors.clear(); + for (std::size_t i = 0; i != old_size; ++i) { + compressed_slot_t neighbor_slot = neighbors[i]; + node_t neighbor = node_at_(neighbor_slot); + if (allow_member(member_cref_t{neighbor.ckey(), neighbor_slot})) + neighbors.push_back(neighbor_slot); + } + } + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), nodes_count); + return do_tasks.load(); + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(processed.load(), nodes_count); + } + + private: + inline static precomputed_constants_t precompute_(index_config_t const& config) noexcept { + precomputed_constants_t pre; + pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); + pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + return pre; + } + + using span_bytes_t = span_gt; + + inline span_bytes_t node_bytes_(node_t node) const noexcept { return {node.tape(), node_bytes_(node.level())}; } + inline std::size_t node_bytes_(level_t level) const noexcept { + return node_head_bytes_() + node_neighbors_bytes_(level); + } + inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); } + inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { + return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; + } + + span_bytes_t node_malloc_(level_t level) noexcept { + std::size_t node_bytes = node_bytes_(level); + byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes); + return data ? 
span_bytes_t{data, node_bytes} : span_bytes_t{}; + } + + node_t node_make_(vector_key_t key, level_t level) noexcept { + span_bytes_t node_bytes = node_malloc_(level); + if (!node_bytes) + return {}; + + std::memset(node_bytes.data(), 0, node_bytes.size()); + node_t node{(byte_t*)node_bytes.data()}; + node.key(key); + node.level(level); + return node; + } + + node_t node_make_copy_(span_bytes_t old_bytes) noexcept { + byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.size()); + if (!data) + return {}; + std::memcpy(data, old_bytes.data(), old_bytes.size()); + return node_t{data}; + } + + void node_free_(std::size_t idx) noexcept { + if (viewed_file_) + return; + + node_t& node = nodes_[idx]; + tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); + node = node_t{}; + } + + inline node_t node_at_(std::size_t idx) const noexcept { return nodes_[idx]; } + inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } + + inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { + return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; + } + + inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { + return level ? neighbors_non_base_(node, level) : neighbors_base_(node); + } + + struct node_lock_t { + nodes_mutexes_t& mutexes; + std::size_t slot; + inline ~node_lock_t() noexcept { mutexes.atomic_reset(slot); } + }; + + inline node_lock_t node_lock_(std::size_t slot) const noexcept { + while (nodes_mutexes_.atomic_set(slot)) + ; + return {nodes_mutexes_, slot}; + } + + template + void connect_node_across_levels_( // + value_at&& value, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t node_slot, std::size_t entry_slot, level_t max_level, level_t target_level, // + index_update_config_t const& config, context_t& context) usearch_noexcept_m { + + // Go down the level, tracking only the closest match + std::size_t closest_slot = search_for_one_( // + value, metric, prefetch, // + entry_slot, max_level, target_level, context); + + // From `target_level` down perform proper extensive search + for (level_t level = (std::min)(target_level, max_level); level >= 0; --level) { + // TODO: Handle out of memory conditions + search_to_insert_(value, metric, prefetch, closest_slot, node_slot, level, config.expansion, context); + closest_slot = connect_new_node_(metric, node_slot, level, context); + reconnect_neighbor_nodes_(metric, node_slot, value, level, context); + } + } + + template + std::size_t connect_new_node_( // + metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m { + + node_t new_node = node_at_(new_slot); + top_candidates_t& top = context.top_candidates; + + // Outgoing links from `new_slot`: + neighbors_ref_t new_neighbors = neighbors_(new_node, level); + { + usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); + candidates_view_t top_view = refine_(metric, config_.connectivity, top, context); + + for (std::size_t idx = 0; idx != top_view.size(); idx++) { + usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); + usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); + new_neighbors.push_back(top_view[idx].slot); + } + } + + return new_neighbors[0]; + } + + template + void reconnect_neighbor_nodes_( // + metric_at&& metric, std::size_t new_slot, value_at&& value, level_t level, + context_t& 
context) usearch_noexcept_m { + + node_t new_node = node_at_(new_slot); + top_candidates_t& top = context.top_candidates; + neighbors_ref_t new_neighbors = neighbors_(new_node, level); + + // Reverse links from the neighbors: + std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; + for (compressed_slot_t close_slot : new_neighbors) { + if (close_slot == new_slot) + continue; + node_lock_t close_lock = node_lock_(close_slot); + node_t close_node = node_at_(close_slot); + + neighbors_ref_t close_header = neighbors_(close_node, level); + usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption"); + usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); + usearch_assert_m(level <= close_node.level(), "Linking to missing level"); + + // If `new_slot` is already present in the neighboring connections of `close_slot` + // then no need to modify any connections or run the heuristics. + if (close_header.size() < connectivity_max) { + close_header.push_back(static_cast(new_slot)); + continue; + } + + // To fit a new connection we need to drop an existing one. + top.clear(); + usearch_assert_m((top.reserve(close_header.size() + 1)), "The memory must have been reserved in `add`"); + top.insert_reserved( + {context.measure(value, citerator_at(close_slot), metric), static_cast(new_slot)}); + for (compressed_slot_t successor_slot : close_header) + top.insert_reserved( + {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); + + // Export the results: + close_header.clear(); + candidates_view_t top_view = refine_(metric, connectivity_max, top, context); + for (std::size_t idx = 0; idx != top_view.size(); idx++) + close_header.push_back(top_view[idx].slot); + } + } + + level_t choose_random_level_(std::default_random_engine& level_generator) const noexcept { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity; + return (level_t)r; + } + + struct candidates_range_t; + class candidates_iterator_t { + friend struct candidates_range_t; + + index_gt const& index_; + neighbors_ref_t neighbors_; + visits_hash_set_t& visits_; + std::size_t current_; + + candidates_iterator_t& skip_missing() noexcept { + if (!visits_.size()) + return *this; + while (current_ != neighbors_.size()) { + compressed_slot_t neighbor_slot = neighbors_[current_]; + if (visits_.test(neighbor_slot)) + current_++; + else + break; + } + return *this; + } + + public: + using element_t = compressed_slot_t; + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + reference operator*() const noexcept { return slot(); } + candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits, + std::size_t progress) noexcept + : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {} + candidates_iterator_t operator++(int) noexcept { + return candidates_iterator_t(index_, visits_, neighbors_, current_ + 1).skip_missing(); + } + candidates_iterator_t& operator++() noexcept { + ++current_; + skip_missing(); + return *this; + } + bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; } + bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; } + + 
vector_key_t key() const noexcept { return index_->node_at_(slot()).key(); } + compressed_slot_t slot() const noexcept { return neighbors_[current_]; } + friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); } + friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); } + }; + + struct candidates_range_t { + index_gt const& index; + neighbors_ref_t neighbors; + visits_hash_set_t& visits; + + candidates_iterator_t begin() const noexcept { + return candidates_iterator_t{index, neighbors, visits, 0}.skip_missing(); + } + candidates_iterator_t end() const noexcept { return {index, neighbors, visits, neighbors.size()}; } + }; + + template + std::size_t search_for_one_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { + + visits_hash_set_t& visits = context.visits; + visits.clear(); + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(closest_slot), citerator_at(closest_slot + 1)); + + distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); + for (level_t level = begin_level; level > end_level; --level) { + bool changed; + do { + changed = false; + node_lock_t closest_lock = node_lock_(closest_slot); + neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, closest_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Actual traversal + for (compressed_slot_t candidate_slot : closest_neighbors) { + distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); + if (candidate_dist < closest_dist) { + closest_dist = candidate_dist; + closest_slot = candidate_slot; + changed = true; + } + } + context.iteration_cycles++; + } while (changed); + } + return closest_slot; + } + + /** + * @brief Traverses a layer of a graph, to find the best place to insert a new node. + * Locks the nodes in the process, assuming other threads are updating neighbors lists. + * @return `true` if procedure succeeded, `false` if run out of memory. 
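+ *
+ * Uses two bounded heaps from the calling context: `next_candidates_t` acts as a
+ * min-heap over distances (stored negated), holding the frontier of nodes still to be
+ * expanded, while `top_candidates_t` keeps at most @p top_limit best matches found so
+ * far; its worst element defines the current search radius used for pruning.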
+ */ + template + bool search_to_insert_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + std::size_t start_slot, std::size_t new_slot, level_t level, std::size_t top_limit, + context_t& context) noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, static_cast(start_slot)}); + top.insert_reserved({radius, static_cast(start_slot)}); + visits.set(static_cast(start_slot)); + + while (!next.empty()) { + + candidate_t candidacy = next.top(); + if ((-candidacy.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + compressed_slot_t candidate_slot = candidacy.slot; + if (new_slot == candidate_slot) + continue; + node_t candidate_ref = node_at_(candidate_slot); + node_lock_t candidate_lock = node_lock_(candidate_slot); + neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + // node_lock_t successor_lock = node_lock_(successor_slot); + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + // This will automatically evict poor matches: + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + return true; + } + + /** + * @brief Traverses the @b base layer of a graph, to find a close match. + * Doesn't lock any nodes, assuming read-only simultaneous access. + * @return `true` if procedure succeeded, `false` if run out of memory. 
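+ *
+ * Unlike `search_to_insert_`, every reached node is filtered through the @p predicate
+ * before being added to the top candidates, including the entry point, so banned
+ * members are still expanded during traversal but never returned as results.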
+ */ + template + bool search_to_find_in_base_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, // + std::size_t start_slot, std::size_t expansion, context_t& context) const usearch_noexcept_m { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + std::size_t const top_limit = expansion; + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot + 1)); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + usearch_assert_m(next.capacity(), "The `max_heap_gt` must have been reserved in the search entry point"); + next.insert_reserved({-radius, static_cast(start_slot)}); + visits.set(static_cast(start_slot)); + + // Don't populate the top list if the predicate is not satisfied + if (is_dummy() || predicate(member_cref_t{node_at_(start_slot).ckey(), start_slot})) { + usearch_assert_m(top.capacity(), + "The `sorted_buffer_gt` must have been reserved in the search entry point"); + top.insert_reserved({radius, static_cast(start_slot)}); + } + + while (!next.empty()) { + + candidate_t candidate = next.top(); + if ((-candidate.distance) > radius) + break; + + next.pop(); + context.iteration_cycles++; + + neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot)); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + if (is_dummy() || + predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + + return true; + } + + /** + * @brief Iterates through all members, without actually touching the index. + */ + template + void search_exact_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, // + std::size_t count, context_t& context) const noexcept { + + top_candidates_t& top = context.top_candidates; + top.clear(); + top.reserve(count); + for (std::size_t i = 0; i != size(); ++i) { + if (!is_dummy()) + if (!predicate(at(i))) + continue; + + distance_t distance = context.measure(query, citerator_at(i), metric); + top.insert(candidate_t{distance, static_cast(i)}, count); + } + } + + /** + * @brief This algorithm from the original paper implements a heuristic, + * that massively reduces the number of connections a point has, + * to keep only the neighbors, that are from each other. 
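+ *
+ * Candidates are scanned in ascending distance from the query; a candidate is kept
+ * only if it is closer to the query than to every already-kept neighbor, so the
+ * retained links stay far from each other instead of clustering on one side.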
+ */ + template + candidates_view_t refine_( // + metric_at&& metric, // + std::size_t needed, top_candidates_t& top, context_t& context) const noexcept { + + top.sort_ascending(); + candidate_t* top_data = top.data(); + std::size_t const top_count = top.size(); + if (top_count < needed) + return {top_data, top_count}; + + std::size_t submitted_count = 1; + std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`. + while (submitted_count < needed && consumed_count < top_count) { + candidate_t candidate = top_data[consumed_count]; + bool good = true; + for (std::size_t idx = 0; idx < submitted_count; idx++) { + candidate_t submitted = top_data[idx]; + distance_t inter_result_dist = context.measure( // + citerator_at(candidate.slot), // + citerator_at(submitted.slot), // + metric); + if (inter_result_dist < candidate.distance) { + good = false; + break; + } + } + + if (good) { + top_data[submitted_count] = top_data[consumed_count]; + submitted_count++; + } + consumed_count++; + } + + top.shrink(submitted_count); + return {top_data, submitted_count}; + } }; struct join_result_t { - error_t error{}; - std::size_t intersection_size{}; - std::size_t engagements{}; - std::size_t visited_members{}; - std::size_t computed_distances{}; - - explicit operator bool() const noexcept { return !error; } - join_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } + error_t error{}; + std::size_t intersection_size{}; + std::size_t engagements{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + join_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; /** -* @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets -* to perform fast one-to-one matching between two large collections -* of vectors, using approximate nearest neighbors search. -* -* @param[inout] man_to_woman Container to map ::first keys to ::second. -* @param[inout] woman_to_man Container to map ::second keys to ::first. -* @param[in] executor Thread-pool to execute the job in parallel. -* @param[in] progress Callback to report the execution progress. -*/ + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::men keys to ::women. + * @param[inout] woman_to_man Container to map ::women keys to ::men. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. 
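+ *
+ * A hedged usage sketch (the index, value-array, metric, and mapping names are
+ * illustrative): both collections are passed explicitly together with their values
+ * and metrics, and the executor and progress callback keep their defaults.
+ *
+ *      join_result_t result = unum::usearch::join(
+ *          men, women, men_vectors, women_vectors,
+ *          men_metric, women_metric, index_join_config_t{},
+ *          man_to_woman, woman_to_man);
+ *      if (result)
+ *          std::printf("matched %zu pairs\n", result.intersection_size);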
+ */ template < // - typename men_at, // - typename women_at, // - typename men_values_at, // - typename women_values_at, // - typename men_metric_at, // - typename women_metric_at, // - - typename man_to_woman_at = dummy_key_to_key_mapping_t, // - typename woman_to_man_at = dummy_key_to_key_mapping_t, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > + typename men_at, // + typename women_at, // + typename men_values_at, // + typename women_values_at, // + typename men_metric_at, // + typename women_metric_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > static join_result_t join( // - men_at const& men, // - women_at const& women, // - men_values_at const& men_values, // - women_values_at const& women_values, // - men_metric_at&& men_metric, // - women_metric_at&& women_metric, // - - index_join_config_t config = {}, // - man_to_woman_at&& man_to_woman = man_to_woman_at{}, // - woman_to_man_at&& woman_to_man = woman_to_man_at{}, // - executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}) noexcept { - - if (women.size() < men.size()) - return unum::usearch::join( // - women, men, // - women_values, men_values, // - std::forward(women_metric), std::forward(men_metric), // - - config, // - std::forward(woman_to_man), // - std::forward(man_to_woman), // - std::forward(executor), // - std::forward(progress)); - - join_result_t result; - - // Sanity checks and argument validation: - if (&men == &women) - return result.failed("Can't join with itself, consider copying"); - - if (config.max_proposals == 0) - config.max_proposals = std::log(men.size()) + executor.size(); - - using proposals_count_t = std::uint16_t; - config.max_proposals = (std::min)(men.size(), config.max_proposals); - - using distance_t = typename men_at::distance_t; - using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; - using man_key_t = typename men_at::vector_key_t; - using woman_key_t = typename women_at::vector_key_t; - - // Use the `compressed_slot_t` type of the larger collection - using compressed_slot_t = typename women_at::compressed_slot_t; - using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - - // Create an atomic queue, as a ring structure, from/to which - // free men will be added/pulled. - std::mutex free_men_mutex{}; - ring_gt free_men; - free_men.reserve(men.size()); - for (std::size_t i = 0; i != men.size(); ++i) - free_men.push(static_cast(i)); - - // We are gonna need some temporary memory. 
- buffer_gt proposal_counts(men.size()); - buffer_gt man_to_woman_slots(men.size()); - buffer_gt woman_to_man_slots(women.size()); - if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) - return result.failed("Can't temporary mappings"); - - compressed_slot_t missing_slot; - std::memset((void*)&missing_slot, 0xFF, sizeof(compressed_slot_t)); - std::memset((void*)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); - std::memset((void*)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); - std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); - - // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. - bitset_t men_locks(men.size()), women_locks(women.size()); - if (!men_locks || !women_locks) - return result.failed("Can't allocate locks"); - - std::atomic rounds{0}; - std::atomic engagements{0}; - std::atomic computed_distances{0}; - std::atomic visited_members{0}; - std::atomic atomic_error{nullptr}; - - // Concurrently process all the men - executor.parallel([&](std::size_t thread_idx) { - index_search_config_t search_config; - search_config.expansion = config.expansion; - search_config.exact = config.exact; - search_config.thread = thread_idx; - compressed_slot_t free_man_slot; - - // While there exist a free man who still has a woman to propose to. - while (!atomic_error.load(std::memory_order_relaxed)) { - std::size_t passed_rounds = 0; - std::size_t total_rounds = 0; - { - std::unique_lock pop_lock(free_men_mutex); - if (!free_men.try_pop(free_man_slot)) - // Primary exit path, we have exhausted the list of candidates - break; - passed_rounds = ++rounds; - total_rounds = passed_rounds + free_men.size(); - } - if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { - atomic_error.store("Terminated by user"); - break; - } - while (men_locks.atomic_set(free_man_slot)) - ; - - proposals_count_t& free_man_proposals = proposal_counts[free_man_slot]; - if (free_man_proposals >= config.max_proposals) - continue; - - // Find the closest woman, to whom this man hasn't proposed yet. 
- ++free_man_proposals; - auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); - visited_members += candidates.visited_members; - computed_distances += candidates.computed_distances; - if (!candidates) { - atomic_error = candidates.error.release(); - break; - } - - auto match = candidates.back(); - auto woman = match.member; - while (women_locks.atomic_set(woman.slot)) - ; - - compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; - bool woman_is_free = husband_slot == missing_slot; - if (woman_is_free) { - // Engagement - man_to_woman_slots[free_man_slot] = woman.slot; - woman_to_man_slots[woman.slot] = free_man_slot; - engagements++; - } else { - distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); - distance_t distance_from_candidate = match.distance; - if (distance_from_husband > distance_from_candidate) { - // Break-up - while (men_locks.atomic_set(husband_slot)) - ; - man_to_woman_slots[husband_slot] = missing_slot; - men_locks.atomic_reset(husband_slot); - - // New Engagement - man_to_woman_slots[free_man_slot] = woman.slot; - woman_to_man_slots[woman.slot] = free_man_slot; - engagements++; - - std::unique_lock push_lock(free_men_mutex); - free_men.push(husband_slot); - } else { - std::unique_lock push_lock(free_men_mutex); - free_men.push(free_man_slot); - } - } - - men_locks.atomic_reset(free_man_slot); - women_locks.atomic_reset(woman.slot); - } - }); - - if (atomic_error) - return result.failed(atomic_error.load()); - - // Export the "slots" into keys: - std::size_t intersection_size = 0; - for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { - compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; - if (woman_slot != missing_slot) { - man_key_t man = men.at(man_slot).key; - woman_key_t woman = women.at(woman_slot).key; - man_to_woman[man] = woman; - woman_to_man[woman] = man; - intersection_size++; - } - } - - // Export stats - result.engagements = engagements; - result.intersection_size = intersection_size; - result.computed_distances = computed_distances; - result.visited_members = visited_members; - return result; + men_at const& men, // + women_at const& women, // + men_values_at const& men_values, // + women_values_at const& women_values, // + men_metric_at&& men_metric, // + women_metric_at&& women_metric, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + if (women.size() < men.size()) + return unum::usearch::join( // + women, men, // + women_values, men_values, // + std::forward(women_metric), std::forward(men_metric), // + + config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); + + join_result_t result; + + // Sanity checks and argument validation: + if (&men == &women) + return result.failed("Can't join with itself, consider copying"); + + if (config.max_proposals == 0) + config.max_proposals = std::log(men.size()) + executor.size(); + + using proposals_count_t = std::uint16_t; + config.max_proposals = (std::min)(men.size(), config.max_proposals); + + using distance_t = typename men_at::distance_t; + using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; + using man_key_t = typename men_at::vector_key_t; + using woman_key_t = 
typename women_at::vector_key_t; + + // Use the `compressed_slot_t` type of the larger collection + using compressed_slot_t = typename women_at::compressed_slot_t; + using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + // Create an atomic queue, as a ring structure, from/to which + // free men will be added/pulled. + std::mutex free_men_mutex{}; + ring_gt free_men; + free_men.reserve(men.size()); + for (std::size_t i = 0; i != men.size(); ++i) + free_men.push(static_cast(i)); + + // We are gonna need some temporary memory. + buffer_gt proposal_counts(men.size()); + buffer_gt man_to_woman_slots(men.size()); + buffer_gt woman_to_man_slots(women.size()); + if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) + return result.failed("Can't temporary mappings"); + + compressed_slot_t missing_slot; + std::memset((void*)&missing_slot, 0xFF, sizeof(compressed_slot_t)); + std::memset((void*)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); + std::memset((void*)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); + std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); + + // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. + bitset_t men_locks(men.size()), women_locks(women.size()); + if (!men_locks || !women_locks) + return result.failed("Can't allocate locks"); + + std::atomic rounds{0}; + std::atomic engagements{0}; + std::atomic computed_distances{0}; + std::atomic visited_members{0}; + std::atomic atomic_error{nullptr}; + + // Concurrently process all the men + executor.parallel([&](std::size_t thread_idx) { + index_search_config_t search_config; + search_config.expansion = config.expansion; + search_config.exact = config.exact; + search_config.thread = thread_idx; + compressed_slot_t free_man_slot; + + // While there exist a free man who still has a woman to propose to. + while (!atomic_error.load(std::memory_order_relaxed)) { + std::size_t passed_rounds = 0; + std::size_t total_rounds = 0; + { + std::unique_lock pop_lock(free_men_mutex); + if (!free_men.try_pop(free_man_slot)) + // Primary exit path, we have exhausted the list of candidates + break; + passed_rounds = ++rounds; + total_rounds = passed_rounds + free_men.size(); + } + if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { + atomic_error.store("Terminated by user"); + break; + } + while (men_locks.atomic_set(free_man_slot)) + ; + + proposals_count_t& free_man_proposals = proposal_counts[free_man_slot]; + if (free_man_proposals >= config.max_proposals) + continue; + + // Find the closest woman, to whom this man hasn't proposed yet. 
+ ++free_man_proposals; + auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); + visited_members += candidates.visited_members; + computed_distances += candidates.computed_distances; + if (!candidates) { + atomic_error = candidates.error.release(); + break; + } + + auto match = candidates.back(); + auto woman = match.member; + while (women_locks.atomic_set(woman.slot)) + ; + + compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; + bool woman_is_free = husband_slot == missing_slot; + if (woman_is_free) { + // Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + } else { + distance_t distance_from_husband = women_metric(women_values[woman.slot], men_values[husband_slot]); + distance_t distance_from_candidate = match.distance; + if (distance_from_husband > distance_from_candidate) { + // Break-up + while (men_locks.atomic_set(husband_slot)) + ; + man_to_woman_slots[husband_slot] = missing_slot; + men_locks.atomic_reset(husband_slot); + + // New Engagement + man_to_woman_slots[free_man_slot] = woman.slot; + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + + std::unique_lock push_lock(free_men_mutex); + free_men.push(husband_slot); + } else { + std::unique_lock push_lock(free_men_mutex); + free_men.push(free_man_slot); + } + } + + men_locks.atomic_reset(free_man_slot); + women_locks.atomic_reset(woman.slot); + } + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Export the "slots" into keys: + std::size_t intersection_size = 0; + for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { + compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; + if (woman_slot != missing_slot) { + man_key_t man = men.at(man_slot).key; + woman_key_t woman = women.at(woman_slot).key; + man_to_woman[man] = woman; + woman_to_man[woman] = man; + intersection_size++; + } + } + + // Export stats + result.engagements = engagements; + result.intersection_size = intersection_size; + result.computed_distances = computed_distances; + result.visited_members = visited_members; + return result; } } // namespace usearch } // namespace unum -#endif \ No newline at end of file +#endif diff --git a/src/include/usearch/index_dense.hpp b/src/include/usearch/index_dense.hpp index 8512c45..563e039 100644 --- a/src/include/usearch/index_dense.hpp +++ b/src/include/usearch/index_dense.hpp @@ -1,3 +1,9 @@ +/** + * @file index_dense.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine for equi-dimensional dense vectors. 
+ * @date July 26, 2023 + */ #pragma once #include // `aligned_alloc` @@ -39,63 +45,63 @@ static_assert(sizeof(index_dense_head_buffer_t) == 64, "File header should be ex */ struct index_dense_head_t { - // Versioning: - using magic_t = char[7]; - using version_t = std::uint16_t; - - // Versioning: 7 + 2 * 3 = 13 bytes - char const* magic; - misaligned_ref_gt version_major; - misaligned_ref_gt version_minor; - misaligned_ref_gt version_patch; - - // Structural: 4 * 3 = 12 bytes - misaligned_ref_gt kind_metric; - misaligned_ref_gt kind_scalar; - misaligned_ref_gt kind_key; - misaligned_ref_gt kind_compressed_slot; - - // Population: 8 * 3 = 24 bytes - misaligned_ref_gt count_present; - misaligned_ref_gt count_deleted; - misaligned_ref_gt dimensions; - misaligned_ref_gt multi; - - index_dense_head_t(byte_t* ptr) noexcept - : magic((char const*)exchange(ptr, ptr + sizeof(magic_t))), // - version_major(exchange(ptr, ptr + sizeof(version_t))), // - version_minor(exchange(ptr, ptr + sizeof(version_t))), // - version_patch(exchange(ptr, ptr + sizeof(version_t))), // - kind_metric(exchange(ptr, ptr + sizeof(metric_kind_t))), // - kind_scalar(exchange(ptr, ptr + sizeof(scalar_kind_t))), // - kind_key(exchange(ptr, ptr + sizeof(scalar_kind_t))), // - kind_compressed_slot(exchange(ptr, ptr + sizeof(scalar_kind_t))), // - count_present(exchange(ptr, ptr + sizeof(std::uint64_t))), // - count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))), // - dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))), // - multi(exchange(ptr, ptr + sizeof(bool))) {} + // Versioning: + using magic_t = char[7]; + using version_t = std::uint16_t; + + // Versioning: 7 + 2 * 3 = 13 bytes + char const* magic; + misaligned_ref_gt version_major; + misaligned_ref_gt version_minor; + misaligned_ref_gt version_patch; + + // Structural: 4 * 3 = 12 bytes + misaligned_ref_gt kind_metric; + misaligned_ref_gt kind_scalar; + misaligned_ref_gt kind_key; + misaligned_ref_gt kind_compressed_slot; + + // Population: 8 * 3 = 24 bytes + misaligned_ref_gt count_present; + misaligned_ref_gt count_deleted; + misaligned_ref_gt dimensions; + misaligned_ref_gt multi; + + index_dense_head_t(byte_t* ptr) noexcept + : magic((char const*)exchange(ptr, ptr + sizeof(magic_t))), // + version_major(exchange(ptr, ptr + sizeof(version_t))), // + version_minor(exchange(ptr, ptr + sizeof(version_t))), // + version_patch(exchange(ptr, ptr + sizeof(version_t))), // + kind_metric(exchange(ptr, ptr + sizeof(metric_kind_t))), // + kind_scalar(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_key(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_compressed_slot(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + count_present(exchange(ptr, ptr + sizeof(std::uint64_t))), // + count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))), // + dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))), // + multi(exchange(ptr, ptr + sizeof(bool))) {} }; struct index_dense_head_result_t { - index_dense_head_buffer_t buffer; - index_dense_head_t head; - error_t error; + index_dense_head_buffer_t buffer; + index_dense_head_t head; + error_t error; - explicit operator bool() const noexcept { return !error; } - index_dense_head_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } + explicit operator bool() const noexcept { return !error; } + index_dense_head_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } }; struct index_dense_config_t : public index_config_t { - 
std::size_t expansion_add = default_expansion_add(); - std::size_t expansion_search = default_expansion_search(); - bool exclude_vectors = false; - bool multi = false; + std::size_t expansion_add = default_expansion_add(); + std::size_t expansion_search = default_expansion_search(); + bool exclude_vectors = false; + bool multi = false; - /** + /** * @brief Allows you to reduce RAM consumption by avoiding * reverse-indexing keys-to-vectors, and only keeping * the vectors-to-keys mappings. @@ -103,64 +109,64 @@ struct index_dense_config_t : public index_config_t { * ! This configuration parameter doesn't affect the serialized file, * ! and is not preserved between runs. Makes sense for small vector * ! representations that fit ina single cache line. - */ - bool enable_key_lookups = true; + */ + bool enable_key_lookups = true; - index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} + index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} - index_dense_config_t(std::size_t c = default_connectivity(), std::size_t ea = default_expansion_add(), - std::size_t es = default_expansion_search()) noexcept - : index_config_t(c), expansion_add(ea ? ea : default_expansion_add()), - expansion_search(es ? es : default_expansion_search()) {} + index_dense_config_t(std::size_t c = default_connectivity(), std::size_t ea = default_expansion_add(), + std::size_t es = default_expansion_search()) noexcept + : index_config_t(c), expansion_add(ea ? ea : default_expansion_add()), + expansion_search(es ? es : default_expansion_search()) {} }; struct index_dense_clustering_config_t { - std::size_t min_clusters = 0; - std::size_t max_clusters = 0; - enum mode_t { - merge_smallest_k, - merge_closest_k, - } mode = merge_smallest_k; + std::size_t min_clusters = 0; + std::size_t max_clusters = 0; + enum mode_t { + merge_smallest_k, + merge_closest_k, + } mode = merge_smallest_k; }; struct index_dense_serialization_config_t { - bool exclude_vectors = false; - bool use_64_bit_dimensions = false; + bool exclude_vectors = false; + bool use_64_bit_dimensions = false; }; struct index_dense_copy_config_t : public index_copy_config_t { - bool force_vector_copy = true; + bool force_vector_copy = true; - index_dense_copy_config_t() = default; - index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) {} + index_dense_copy_config_t() = default; + index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) {} }; struct index_dense_metadata_result_t { - index_dense_serialization_config_t config; - index_dense_head_buffer_t head_buffer; - index_dense_head_t head; - error_t error; - - explicit operator bool() const noexcept { return !error; } - index_dense_metadata_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - - index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() {} - - index_dense_metadata_result_t(index_dense_metadata_result_t&& other) noexcept - : config(), head_buffer(), head(head_buffer), error(std::move(other.error)) { - std::memcpy(&config, &other.config, sizeof(other.config)); - std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); - } - - index_dense_metadata_result_t& operator=(index_dense_metadata_result_t&& other) noexcept { - std::memcpy(&config, &other.config, sizeof(other.config)); - std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); - error = std::move(other.error); - return 
*this; - } + index_dense_serialization_config_t config; + index_dense_head_buffer_t head_buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_metadata_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() {} + + index_dense_metadata_result_t(index_dense_metadata_result_t&& other) noexcept + : config(), head_buffer(), head(head_buffer), error(std::move(other.error)) { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + } + + index_dense_metadata_result_t& operator=(index_dense_metadata_result_t&& other) noexcept { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + error = std::move(other.error); + return *this; + } }; /** @@ -168,65 +174,65 @@ struct index_dense_metadata_result_t { * without loading it or mapping the whole binary file. */ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* file_path) noexcept { - index_dense_metadata_result_t result; - std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); - if (!file) - return result.failed(std::strerror(errno)); - - // Read the header - std::size_t read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); - if (!read) - return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); - - // Check if the file immediately starts with the index, instead of vectors - result.config.exclude_vectors = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - - if (std::fseek(file.get(), 0L, SEEK_END) != 0) - return result.failed("Can't infer file size"); - - // Check if it starts with 32-bit - std::size_t const file_size = std::ftell(file.get()); - - std::uint32_t dimensions_u32[2]{0}; - std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); - std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); - - std::uint64_t dimensions_u64[2]{0}; - std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); - std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); - - // Check if it starts with 32-bit - if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { - if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) - return result.failed(std::strerror(errno)); - read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); - if (!read) - return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); - - result.config.exclude_vectors = false; - result.config.use_64_bit_dimensions = false; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - } - - // Check if it starts with 64-bit - if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { - if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) - return result.failed(std::strerror(errno)); - read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); - if (!read) - return result.failed(std::feof(file.get()) ? "End of file reached!" 
: std::strerror(errno)); - - // Check if it starts with 64-bit - result.config.exclude_vectors = false; - result.config.use_64_bit_dimensions = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - } - - return result.failed("Not a dense USearch index!"); + index_dense_metadata_result_t result; + std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); + if (!file) + return result.failed(std::strerror(errno)); + + // Read the header + std::size_t read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + if (std::fseek(file.get(), 0L, SEEK_END) != 0) + return result.failed("Can't infer file size"); + + // Check if it starts with 32-bit + std::size_t const file_size = std::ftell(file.get()); + + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" 
: std::strerror(errno)); + + // Check if it starts with 64-bit + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); } /** @@ -234,49 +240,49 @@ inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* */ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_mapped_file_t file, std::size_t offset = 0) noexcept { - index_dense_metadata_result_t result; - - // Read the header - if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) - return result.failed("End of file reached!"); - - byte_t* const file_data = file.data() + offset; - std::size_t const file_size = file.size() - offset; - std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); - - // Check if the file immediately starts with the index, instead of vectors - result.config.exclude_vectors = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - - // Check if it starts with 32-bit - std::uint32_t dimensions_u32[2]{0}; - std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); - std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); - - std::uint64_t dimensions_u64[2]{0}; - std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); - std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); - - // Check if it starts with 32-bit - if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { - std::memcpy(&result.head_buffer, file_data + offset_if_u32, sizeof(index_dense_head_buffer_t)); - result.config.exclude_vectors = false; - result.config.use_64_bit_dimensions = false; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - } - - // Check if it starts with 64-bit - if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { - std::memcpy(&result.head_buffer, file_data + offset_if_u64, sizeof(index_dense_head_buffer_t)); - result.config.exclude_vectors = false; - result.config.use_64_bit_dimensions = true; - if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) - return result; - } - - return result.failed("Not a dense USearch index!"); + index_dense_metadata_result_t result; + + // Read the header + if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) + return result.failed("End of file reached!"); + + byte_t* const file_data = file.data() + offset; + std::size_t const file_size = file.size() - offset; + std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + // Check if it starts with 32-bit + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + 
sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u32, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u64, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); } /** @@ -297,324 +303,330 @@ inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_map */ template // class index_dense_gt { -public: - using vector_key_t = key_at; - using key_t = vector_key_t; - using compressed_slot_t = compressed_slot_at; - using distance_t = distance_punned_t; - using metric_t = metric_punned_t; + public: + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using distance_t = distance_punned_t; + using metric_t = metric_punned_t; - using member_ref_t = member_ref_gt; - using member_cref_t = member_cref_gt; + using member_ref_t = member_ref_gt; + using member_cref_t = member_cref_gt; - using head_t = index_dense_head_t; - using head_buffer_t = index_dense_head_buffer_t; - using head_result_t = index_dense_head_result_t; + using head_t = index_dense_head_t; + using head_buffer_t = index_dense_head_buffer_t; + using head_result_t = index_dense_head_result_t; - using serialization_config_t = index_dense_serialization_config_t; + using serialization_config_t = index_dense_serialization_config_t; - using dynamic_allocator_t = aligned_allocator_gt; - using tape_allocator_t = memory_mapping_allocator_gt<64>; + using dynamic_allocator_t = aligned_allocator_gt; + using tape_allocator_t = memory_mapping_allocator_gt<64>; -private: - /// @brief Schema: input buffer, bytes in input buffer, output buffer. - using cast_t = std::function; - /// @brief Punned index. - using index_t = index_gt< // - distance_t, vector_key_t, compressed_slot_t, // - dynamic_allocator_t, tape_allocator_t>; - using index_allocator_t = aligned_allocator_gt; + private: + /// @brief Schema: input buffer, bytes in input buffer, output buffer. + using cast_t = std::function; + /// @brief Punned index. + using index_t = index_gt< // + distance_t, vector_key_t, compressed_slot_t, // + dynamic_allocator_t, tape_allocator_t>; + using index_allocator_t = aligned_allocator_gt; - using member_iterator_t = typename index_t::member_iterator_t; - using member_citerator_t = typename index_t::member_citerator_t; + using member_iterator_t = typename index_t::member_iterator_t; + using member_citerator_t = typename index_t::member_citerator_t; - /// @brief Punned metric object. - class metric_proxy_t { - index_dense_gt const* index_ = nullptr; + /// @brief Punned metric object. 
+ class metric_proxy_t { + index_dense_gt const* index_ = nullptr; - public: - metric_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + public: + metric_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} - inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); } - inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } + inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } - inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); } - inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept { - return f(v(a), v(b)); - } + inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept { + return f(v(a), v(b)); + } - inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); } + inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); } - inline byte_t const* v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } - inline byte_t const* v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } - inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return index_->metric_(a, b); } - }; + inline byte_t const* v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline byte_t const* v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return index_->metric_(a, b); } + }; - index_dense_config_t config_; - index_t* typed_ = nullptr; + index_dense_config_t config_; + index_t* typed_ = nullptr; - mutable std::vector cast_buffer_; - struct casts_t { - cast_t from_b1x8; - cast_t from_i8; - cast_t from_f16; - cast_t from_f32; - cast_t from_f64; + mutable std::vector cast_buffer_; + struct casts_t { + cast_t from_b1x8; + cast_t from_i8; + cast_t from_f16; + cast_t from_f32; + cast_t from_f64; - cast_t to_b1x8; - cast_t to_i8; - cast_t to_f16; - cast_t to_f32; - cast_t to_f64; - } casts_; + cast_t to_b1x8; + cast_t to_i8; + cast_t to_f16; + cast_t to_f32; + cast_t to_f64; + } casts_; - /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. - metric_t metric_; + /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. + metric_t metric_; - using vectors_tape_allocator_t = memory_mapping_allocator_gt<8>; - /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. - vectors_tape_allocator_t vectors_tape_allocator_; + using vectors_tape_allocator_t = memory_mapping_allocator_gt<8>; + /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. + vectors_tape_allocator_t vectors_tape_allocator_; - /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. - mutable std::vector vectors_lookup_; + /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. 
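Note: metric_proxy_t above only resolves a member's compressed slot to the raw bytes stored in vectors_lookup_ and forwards to the type-punned metric_, which compares untyped byte_t buffers. A standalone illustration (not the library's code) of what such a byte-level metric amounts to for f32 vectors:

    #include <cstddef>

    // Illustrative only: a "type-punned" metric sees untyped bytes and reinterprets
    // them with the scalar type it was configured for (f32 here). metric_proxy_t's
    // job is merely to turn a member slot into such a byte pointer before calling it.
    using byte_t = unsigned char;

    float l2sq_punned(byte_t const* a_bytes, byte_t const* b_bytes, std::size_t dimensions) {
        // Assumes both buffers hold `dimensions` properly aligned 32-bit floats.
        float const* a = reinterpret_cast<float const*>(a_bytes);
        float const* b = reinterpret_cast<float const*>(b_bytes);
        float sum = 0;
        for (std::size_t i = 0; i != dimensions; ++i) {
            float d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }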
+ mutable std::vector vectors_lookup_; - /// @brief Originally forms and array of integers [0, threads], marking all - mutable std::vector available_threads_; + /// @brief Originally forms and array of integers [0, threads], marking all + mutable std::vector available_threads_; - /// @brief Mutex, controlling concurrent access to `available_threads_`. - mutable std::mutex available_threads_mutex_; + /// @brief Mutex, controlling concurrent access to `available_threads_`. + mutable std::mutex available_threads_mutex_; #if defined(USEARCH_DEFINED_CPP17) - using shared_mutex_t = std::shared_mutex; + using shared_mutex_t = std::shared_mutex; #else - using shared_mutex_t = unfair_shared_mutex_t; + using shared_mutex_t = unfair_shared_mutex_t; #endif - using shared_lock_t = shared_lock_gt; - using unique_lock_t = std::unique_lock; - - struct key_and_slot_t { - vector_key_t key; - compressed_slot_t slot; - - bool any_slot() const { return slot == default_free_value(); } - static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } - }; - - struct lookup_key_hash_t { - using is_transparent = void; - std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash{}(k.key); } - std::size_t operator()(vector_key_t const& k) const noexcept { return std::hash{}(k); } - }; - - struct lookup_key_same_t { - using is_transparent = void; - bool operator()(key_and_slot_t const& a, vector_key_t const& b) const noexcept { return a.key == b; } - bool operator()(vector_key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; } - bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; } - }; - - /// @brief Multi-Map from keys to IDs, and allocated vectors. - flat_hash_multi_set_gt slot_lookup_; - - /// @brief Mutex, controlling concurrent access to `slot_lookup_`. - mutable shared_mutex_t slot_lookup_mutex_; - - /// @brief Ring-shaped queue of deleted entries, to be reused on future insertions. - ring_gt free_keys_; - - /// @brief Mutex, controlling concurrent access to `free_keys_`. - mutable std::mutex free_keys_mutex_; - - /// @brief A constant for the reserved key value, used to mark deleted entries. 
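Note: slot_lookup_ below is keyed by key_and_slot_t, but both lookup_key_hash_t and lookup_key_same_t declare is_transparent, so the multi-map can be probed with a bare vector_key_t without materializing a temporary pair. The same idiom expressed with the standard unordered containers (heterogeneous lookup, C++20) looks roughly like this; the names are illustrative, not the flat_hash_multi_set_gt used here:

    #include <cstdint>
    #include <functional>
    #include <iterator>
    #include <unordered_set>

    struct key_and_slot_t {
        std::uint64_t key;
        std::uint32_t slot;
    };

    struct lookup_hash_t {
        using is_transparent = void; // enables heterogeneous lookup (C++20)
        std::size_t operator()(key_and_slot_t const& p) const noexcept { return std::hash<std::uint64_t>{}(p.key); }
        std::size_t operator()(std::uint64_t k) const noexcept { return std::hash<std::uint64_t>{}(k); }
    };

    struct lookup_equal_t {
        using is_transparent = void;
        bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; }
        bool operator()(key_and_slot_t const& a, std::uint64_t b) const noexcept { return a.key == b; }
        bool operator()(std::uint64_t a, key_and_slot_t const& b) const noexcept { return a == b.key; }
    };

    int main() {
        std::unordered_multiset<key_and_slot_t, lookup_hash_t, lookup_equal_t> lookup;
        lookup.insert({42, 0});
        lookup.insert({42, 1});
        // Probe by the bare key: no temporary key_and_slot_t is constructed.
        auto range = lookup.equal_range(std::uint64_t{42});
        return std::distance(range.first, range.second) == 2 ? 0 : 1;
    }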
- vector_key_t free_key_ = default_free_value(); - -public: - using search_result_t = typename index_t::search_result_t; - using cluster_result_t = typename index_t::cluster_result_t; - using add_result_t = typename index_t::add_result_t; - using stats_t = typename index_t::stats_t; - using match_t = typename index_t::match_t; - - index_dense_gt() = default; - index_dense_gt(index_dense_gt&& other) - : config_(std::move(other.config_)), - - typed_(exchange(other.typed_, nullptr)), // - cast_buffer_(std::move(other.cast_buffer_)), // - casts_(std::move(other.casts_)), // - metric_(std::move(other.metric_)), // - - vectors_tape_allocator_(std::move(other.vectors_tape_allocator_)), // - vectors_lookup_(std::move(other.vectors_lookup_)), // - - available_threads_(std::move(other.available_threads_)), // - slot_lookup_(std::move(other.slot_lookup_)), // - free_keys_(std::move(other.free_keys_)), // - free_key_(std::move(other.free_key_)) {} // - - index_dense_gt& operator=(index_dense_gt&& other) { - swap(other); - return *this; - } - - /** + using shared_lock_t = shared_lock_gt; + using unique_lock_t = std::unique_lock; + + struct key_and_slot_t { + vector_key_t key; + compressed_slot_t slot; + + bool any_slot() const { return slot == default_free_value(); } + static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } + }; + + struct lookup_key_hash_t { + using is_transparent = void; + std::size_t operator()(key_and_slot_t const& k) const noexcept { return std::hash{}(k.key); } + std::size_t operator()(vector_key_t const& k) const noexcept { return std::hash{}(k); } + }; + + struct lookup_key_same_t { + using is_transparent = void; + bool operator()(key_and_slot_t const& a, vector_key_t const& b) const noexcept { return a.key == b; } + bool operator()(vector_key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; } + bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; } + }; + + /// @brief Multi-Map from keys to IDs, and allocated vectors. + flat_hash_multi_set_gt slot_lookup_; + + /// @brief Mutex, controlling concurrent access to `slot_lookup_`. + mutable shared_mutex_t slot_lookup_mutex_; + + /// @brief Ring-shaped queue of deleted entries, to be reused on future insertions. + ring_gt free_keys_; + + /// @brief Mutex, controlling concurrent access to `free_keys_`. + mutable std::mutex free_keys_mutex_; + + /// @brief A constant for the reserved key value, used to mark deleted entries. 
+ vector_key_t free_key_ = default_free_value(); + + public: + using search_result_t = typename index_t::search_result_t; + using cluster_result_t = typename index_t::cluster_result_t; + using add_result_t = typename index_t::add_result_t; + using stats_t = typename index_t::stats_t; + using match_t = typename index_t::match_t; + + index_dense_gt() = default; + index_dense_gt(index_dense_gt&& other) + : config_(std::move(other.config_)), + + typed_(exchange(other.typed_, nullptr)), // + cast_buffer_(std::move(other.cast_buffer_)), // + casts_(std::move(other.casts_)), // + metric_(std::move(other.metric_)), // + + vectors_tape_allocator_(std::move(other.vectors_tape_allocator_)), // + vectors_lookup_(std::move(other.vectors_lookup_)), // + + available_threads_(std::move(other.available_threads_)), // + slot_lookup_(std::move(other.slot_lookup_)), // + free_keys_(std::move(other.free_keys_)), // + free_key_(std::move(other.free_key_)) {} // + + index_dense_gt& operator=(index_dense_gt&& other) { + swap(other); + return *this; + } + + /** * @brief Swaps the contents of this index with another index. * @param other The other index to swap with. - */ - void swap(index_dense_gt& other) { - std::swap(config_, other.config_); - - std::swap(typed_, other.typed_); - std::swap(cast_buffer_, other.cast_buffer_); - std::swap(casts_, other.casts_); - std::swap(metric_, other.metric_); - - std::swap(vectors_tape_allocator_, other.vectors_tape_allocator_); - std::swap(vectors_lookup_, other.vectors_lookup_); - - std::swap(available_threads_, other.available_threads_); - std::swap(slot_lookup_, other.slot_lookup_); - std::swap(free_keys_, other.free_keys_); - std::swap(free_key_, other.free_key_); - } - - ~index_dense_gt() { - if (typed_) - typed_->~index_t(); - index_allocator_t{}.deallocate(typed_, 1); - typed_ = nullptr; - } - - /** + */ + void swap(index_dense_gt& other) { + std::swap(config_, other.config_); + + std::swap(typed_, other.typed_); + std::swap(cast_buffer_, other.cast_buffer_); + std::swap(casts_, other.casts_); + std::swap(metric_, other.metric_); + + std::swap(vectors_tape_allocator_, other.vectors_tape_allocator_); + std::swap(vectors_lookup_, other.vectors_lookup_); + + std::swap(available_threads_, other.available_threads_); + std::swap(slot_lookup_, other.slot_lookup_); + std::swap(free_keys_, other.free_keys_); + std::swap(free_key_, other.free_key_); + } + + ~index_dense_gt() { + if (typed_) + typed_->~index_t(); + index_allocator_t{}.deallocate(typed_, 1); + typed_ = nullptr; + } + + /** * @brief Constructs an instance of ::index_dense_gt. * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. * @param[in] config The index configuration (optional). * @param[in] free_key The key used for freed vectors (optional). * @return An instance of ::index_dense_gt. - */ - static index_dense_gt make( // - metric_t metric, // - index_dense_config_t config = {}, // - vector_key_t free_key = default_free_value()) { - - scalar_kind_t scalar_kind = metric.scalar_kind(); - std::size_t hardware_threads = std::thread::hardware_concurrency(); - - index_dense_gt result; - result.config_ = config; - result.cast_buffer_.resize(hardware_threads * metric.bytes_per_vector()); - result.casts_ = make_casts_(scalar_kind); - result.metric_ = metric; - result.free_key_ = free_key; - - // Fill the thread IDs. 
- result.available_threads_.resize(hardware_threads); - std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); - - // Available since C11, but only C++17, so we use the C version. - index_t* raw = index_allocator_t{}.allocate(1); - new (raw) index_t(config); - result.typed_ = raw; - return result; - } - - static index_dense_gt make(char const* path, bool view = false) { - index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); - if (!meta) - return {}; - metric_punned_t metric(meta.head.dimensions, meta.head.kind_metric, meta.head.kind_scalar); - index_dense_gt result = make(metric); - if (!result) - return result; - if (view) - result.view(path); - else - result.load(path); - return result; - } - - explicit operator bool() const { return typed_; } - std::size_t connectivity() const { return typed_->connectivity(); } - std::size_t size() const { return typed_->size() - free_keys_.size(); } - std::size_t capacity() const { return typed_->capacity(); } - std::size_t max_level() const noexcept { return typed_->max_level(); } - index_dense_config_t const& config() const { return config_; } - index_limits_t const& limits() const { return typed_->limits(); } - bool multi() const { return config_.multi; } - - // The metric and its properties - metric_t const& metric() const { return metric_; } - void change_metric(metric_t metric) { metric_ = std::move(metric); } - - scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } - std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } - std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } - std::size_t dimensions() const noexcept { return metric_.dimensions(); } - - // Fetching and changing search criteria - std::size_t expansion_add() const { return config_.expansion_add; } - std::size_t expansion_search() const { return config_.expansion_search; } - void change_expansion_add(std::size_t n) { config_.expansion_add = n; } - void change_expansion_search(std::size_t n) { config_.expansion_search = n; } - - member_citerator_t cbegin() const { return typed_->cbegin(); } - member_citerator_t cend() const { return typed_->cend(); } - member_citerator_t begin() const { return typed_->begin(); } - member_citerator_t end() const { return typed_->end(); } - member_iterator_t begin() { return typed_->begin(); } - member_iterator_t end() { return typed_->end(); } - - stats_t stats() const { return typed_->stats(); } - stats_t stats(std::size_t level) const { return typed_->stats(level); } - stats_t stats(stats_t* stats_per_level, std::size_t max_level) const { - return typed_->stats(stats_per_level, max_level); - } - - dynamic_allocator_t const& allocator() const { return typed_->dynamic_allocator(); } - vector_key_t const& free_key() const { return free_key_; } - - /** + */ + static index_dense_gt make( // + metric_t metric, // + index_dense_config_t config = {}, // + vector_key_t free_key = default_free_value()) { + + scalar_kind_t scalar_kind = metric.scalar_kind(); + std::size_t hardware_threads = std::thread::hardware_concurrency(); + + index_dense_gt result; + result.config_ = config; + result.cast_buffer_.resize(hardware_threads * metric.bytes_per_vector()); + result.casts_ = make_casts_(scalar_kind); + result.metric_ = metric; + result.free_key_ = free_key; + + // Fill the thread IDs. 
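Note: make() above seeds cast_buffer_ and available_threads_ from std::thread::hardware_concurrency() and placement-constructs the type-punned index_t; the path-based overload first reads the header to recover dimensions, metric kind, and scalar kind. A hedged construction sketch, assuming the index_dense_t alias, the metric_kind_t/scalar_kind_t enumerators, and the two-argument index_limits_t constructor from upstream USearch:

    #include <vector>
    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    int main() {
        // 256-dimensional cosine index over f32 scalars; the (dimensions, metric kind,
        // scalar kind) constructor shape matches the one used by make(path) above.
        metric_punned_t metric(256, metric_kind_t::cos_k, scalar_kind_t::f32_k);

        index_dense_t index = index_dense_t::make(metric); // index_dense_t = index_dense_gt<> (assumed alias)
        if (!index)
            return 1;

        // Capacity for members and insertion contexts must be reserved before adding.
        index.reserve(index_limits_t(1024, /*threads=*/1)); // assumed (members, threads) constructor

        std::vector<float> vector(256, 0.0f);
        index.add(42, vector.data()); // f32 add() overload
        return index.contains(42) ? 0 : 1;
    }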
+ result.available_threads_.resize(hardware_threads); + std::iota(result.available_threads_.begin(), result.available_threads_.end(), 0ul); + + // Available since C11, but only C++17, so we use the C version. + index_t* raw = index_allocator_t{}.allocate(1); + new (raw) index_t(config); + result.typed_ = raw; + return result; + } + + static index_dense_gt make(char const* path, bool view = false) { + index_dense_metadata_result_t meta = index_dense_metadata_from_path(path); + if (!meta) + return {}; + metric_punned_t metric(meta.head.dimensions, meta.head.kind_metric, meta.head.kind_scalar); + index_dense_gt result = make(metric); + if (!result) + return result; + if (view) + result.view(path); + else + result.load(path); + return result; + } + + explicit operator bool() const { return typed_; } + std::size_t connectivity() const { return typed_->connectivity(); } + std::size_t size() const { return typed_->size() - free_keys_.size(); } + std::size_t capacity() const { return typed_->capacity(); } + std::size_t max_level() const noexcept { return typed_->max_level(); } + index_dense_config_t const& config() const { return config_; } + index_limits_t const& limits() const { return typed_->limits(); } + bool multi() const { return config_.multi; } + + // The metric and its properties + metric_t const& metric() const { return metric_; } + void change_metric(metric_t metric) { metric_ = std::move(metric); } + + scalar_kind_t scalar_kind() const noexcept { return metric_.scalar_kind(); } + std::size_t bytes_per_vector() const noexcept { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const noexcept { return metric_.scalar_words(); } + std::size_t dimensions() const noexcept { return metric_.dimensions(); } + + // Fetching and changing search criteria + std::size_t expansion_add() const { return config_.expansion_add; } + std::size_t expansion_search() const { return config_.expansion_search; } + void change_expansion_add(std::size_t n) { config_.expansion_add = n; } + void change_expansion_search(std::size_t n) { config_.expansion_search = n; } + + member_citerator_t cbegin() const { return typed_->cbegin(); } + member_citerator_t cend() const { return typed_->cend(); } + member_citerator_t begin() const { return typed_->begin(); } + member_citerator_t end() const { return typed_->end(); } + member_iterator_t begin() { return typed_->begin(); } + member_iterator_t end() { return typed_->end(); } + + stats_t stats() const { return typed_->stats(); } + stats_t stats(std::size_t level) const { return typed_->stats(level); } + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const { + return typed_->stats(stats_per_level, max_level); + } + + dynamic_allocator_t const& allocator() const { return typed_->dynamic_allocator(); } + vector_key_t const& free_key() const { return free_key_; } + + /** * @brief A relatively accurate lower bound on the amount of memory consumed by the system. * In practice it's error will be below 10%. * * @see `serialized_length` for the length of the binary serialized representation. 
- */ - std::size_t memory_usage() const { - return // - typed_->memory_usage(0) + // - typed_->tape_allocator().total_wasted() + // - typed_->tape_allocator().total_reserved() + // - vectors_tape_allocator_.total_allocated(); - } - - static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } - static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } - - struct aggregated_distances_t { - std::size_t count = 0; - distance_t mean = infinite_distance(); - distance_t min = infinite_distance(); - distance_t max = infinite_distance(); - }; - - // clang-format off + */ + std::size_t memory_usage() const { + return // + typed_->memory_usage(0) + // + typed_->tape_allocator().total_wasted() + // + typed_->tape_allocator().total_reserved() + // + vectors_tape_allocator_.total_allocated(); + } + + static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } + static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } + + struct aggregated_distances_t { + std::size_t count = 0; + distance_t mean = infinite_distance(); + distance_t min = infinite_distance(); + distance_t max = infinite_distance(); + }; + + // clang-format off add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_b1x8); } add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_i8); } add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f16); } add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f32); } add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool force_vector_copy = true) { return add_(key, vector, thread, force_vector_copy, casts_.from_f64); } - search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_b1x8); } - search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_i8); } - search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f16); } - search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f32); } - search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f64); } + search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8, config_.expansion_search); } + search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) 
const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8, config_.expansion_search); } + search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16, config_.expansion_search); } + search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32, config_.expansion_search); } + search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64, config_.expansion_search); } + + search_result_t ef_search(b1x8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_b1x8, ef_search); } + search_result_t ef_search(i8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_i8, ef_search); } + search_result_t ef_search(f16_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f16, ef_search); } + search_result_t ef_search(f32_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f32, ef_search); } + search_result_t ef_search(f64_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from_f64, ef_search); } - search_result_t ef_search(b1x8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_b1x8, ef_search); } - search_result_t ef_search(i8_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_i8, ef_search); } - search_result_t ef_search(f16_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f16, ef_search); } - search_result_t ef_search(f32_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f32, ef_search); } - search_result_t ef_search(f64_t const* vector, std::size_t wanted, std::size_t ef_search, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, thread, exact, casts_.from_f64, ef_search); } + template search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_b1x8, 
config_.expansion_search); } + template search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_i8, config_.expansion_search); } + template search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f16, config_.expansion_search); } + template search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f32, config_.expansion_search); } + template search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from_f64, config_.expansion_search); } std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_b1x8); } std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to_i8); } @@ -633,888 +645,896 @@ class index_dense_gt { aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f16); } aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f32); } aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to_f64); } - // clang-format on + // clang-format on - /** + /** * @brief Computes the distance between two managed entities. * If either key maps into more than one vector, will aggregate results * exporting the mean, maximum, and minimum values. 
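Note: the search overloads above now all funnel through a predicate- and expansion-aware search_(): plain search() passes dummy_predicate_t and config_.expansion_search, the new ef_search() overrides the expansion per call, and filtered_search() accepts an arbitrary predicate. A hedged usage sketch, assuming the predicate is invoked with the member key and that search_result_t exposes size(), as in upstream USearch; index_dense_t alias assumed:

    #include <vector>
    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    // `index` is assumed to be a populated 256-dimensional f32 index (see the earlier sketch).
    std::size_t search_even_keys(index_dense_t const& index, std::vector<float> const& query) {
        // Per-call expansion (ef) override; plain search() falls back to config_.expansion_search.
        auto wide = index.ef_search(query.data(), 10, /*ef_search=*/128);

        // Keep only members whose key passes the predicate; the predicate is assumed
        // to receive the vector key, as in upstream USearch.
        auto even = index.filtered_search(query.data(), 10, [](auto key) { return key % 2 == 0; });

        return wide.size() + even.size(); // search_result_t::size() assumed
    }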
- */ - aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { - shared_lock_t lock(slot_lookup_mutex_); - aggregated_distances_t result; - if (!multi()) { - auto a_it = slot_lookup_.find(key_and_slot_t::any_slot(a)); - auto b_it = slot_lookup_.find(key_and_slot_t::any_slot(b)); - bool a_missing = a_it == slot_lookup_.end(); - bool b_missing = b_it == slot_lookup_.end(); - if (a_missing || b_missing) - return result; - - key_and_slot_t a_key_and_slot = *a_it; - byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; - key_and_slot_t b_key_and_slot = *b_it; - byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; - distance_t a_b_distance = metric_(a_vector, b_vector); - - result.mean = result.min = result.max = a_b_distance; - result.count = 1; - return result; - } - - auto a_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(a)); - auto b_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(b)); - bool a_missing = a_range.first == a_range.second; - bool b_missing = b_range.first == b_range.second; - if (a_missing || b_missing) - return result; - - result.min = std::numeric_limits::max(); - result.max = std::numeric_limits::min(); - result.mean = 0; - result.count = 0; - - while (a_range.first != a_range.second) { - key_and_slot_t a_key_and_slot = *a_range.first; - byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; - while (b_range.first != b_range.second) { - key_and_slot_t b_key_and_slot = *b_range.first; - byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; - distance_t a_b_distance = metric_(a_vector, b_vector); - - result.mean += a_b_distance; - result.min = (std::min)(result.min, a_b_distance); - result.max = (std::max)(result.max, a_b_distance); - result.count++; - - // - ++b_range.first; - } - ++a_range.first; - } - - result.mean /= result.count; - return result; - } - - /** + */ + aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { + shared_lock_t lock(slot_lookup_mutex_); + aggregated_distances_t result; + if (!multi()) { + auto a_it = slot_lookup_.find(key_and_slot_t::any_slot(a)); + auto b_it = slot_lookup_.find(key_and_slot_t::any_slot(b)); + bool a_missing = a_it == slot_lookup_.end(); + bool b_missing = b_it == slot_lookup_.end(); + if (a_missing || b_missing) + return result; + + key_and_slot_t a_key_and_slot = *a_it; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + key_and_slot_t b_key_and_slot = *b_it; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean = result.min = result.max = a_b_distance; + result.count = 1; + return result; + } + + auto a_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(a)); + auto b_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(b)); + bool a_missing = a_range.first == a_range.second; + bool b_missing = b_range.first == b_range.second; + if (a_missing || b_missing) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (a_range.first != a_range.second) { + key_and_slot_t a_key_and_slot = *a_range.first; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + while (b_range.first != b_range.second) { + key_and_slot_t b_key_and_slot = *b_range.first; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, 
b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++b_range.first; + } + ++a_range.first; + } + + result.mean /= result.count; + return result; + } + + /** * @brief Identifies a node in a given `level`, that is the closest to the `key`. - */ - cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const { - - // Check if such `key` is even present. - shared_lock_t slots_lock(slot_lookup_mutex_); - auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); - cluster_result_t result; - if (key_range.first == key_range.second) - return result.failed("Key missing!"); - - index_cluster_config_t cluster_config; - thread_lock_t lock = thread_lock_(thread); - cluster_config.thread = lock.thread_id; - cluster_config.expansion = config_.expansion_search; - metric_proxy_t metric{*this}; - auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; - - // Find the closest cluster for any vector under that key. - while (key_range.first != key_range.second) { - key_and_slot_t key_and_slot = *key_range.first; - byte_t const* vector_data = vectors_lookup_[key_and_slot.slot]; - cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); - if (!new_result) - return new_result; - if (new_result.cluster.distance < result.cluster.distance) - result = std::move(new_result); - - ++key_range.first; - } - return result; - } - - /** + */ + cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const { + + // Check if such `key` is even present. + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + cluster_result_t result; + if (key_range.first == key_range.second) + return result.failed("Key missing!"); + + index_cluster_config_t cluster_config; + thread_lock_t lock = thread_lock_(thread); + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + metric_proxy_t metric{*this}; + auto &free_key_ = this->free_key; + auto allow = [&free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + + // Find the closest cluster for any vector under that key. + while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* vector_data = vectors_lookup_[key_and_slot.slot]; + cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); + if (!new_result) + return new_result; + if (new_result.cluster.distance < result.cluster.distance) + result = std::move(new_result); + + ++key_range.first; + } + return result; + } + + /** * @brief Reserves memory for the index and the keyed lookup. * @return `true` if the memory reservation was successful, `false` otherwise. - */ - bool reserve(index_limits_t limits) { - { - unique_lock_t lock(slot_lookup_mutex_); - slot_lookup_.reserve(limits.members); - vectors_lookup_.resize(limits.members); - } - return typed_->reserve(limits); - } - - /** + */ + bool reserve(index_limits_t limits) { + { + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.reserve(limits.members); + vectors_lookup_.resize(limits.members); + + // During reserve, no insertions may be happening, so we can safely overwrite the whole collection. 
+ std::unique_lock available_threads_lock(available_threads_mutex_); + available_threads_.resize(limits.threads()); + std::iota(available_threads_.begin(), available_threads_.end(), 0ul); + } + return typed_->reserve(limits); + } + + /** * @brief Erases all the vectors from the index. * * Will change `size()` to zero, but will keep the same `capacity()`. * Will keep the number of available threads/contexts the same as it was. - */ - void clear() { - unique_lock_t lookup_lock(slot_lookup_mutex_); - - std::unique_lock free_lock(free_keys_mutex_); - typed_->clear(); - slot_lookup_.clear(); - vectors_lookup_.clear(); - free_keys_.clear(); - vectors_tape_allocator_.reset(); - } - - /** + */ + void clear() { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + typed_->clear(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + } + + /** * @brief Erases all members from index, closing files, and returning RAM to OS. * * Will change both `size()` and `capacity()` to zero. * Will deallocate all threads/contexts. * If the index is memory-mapped - releases the mapping and the descriptor. - */ - void reset() { - unique_lock_t lookup_lock(slot_lookup_mutex_); - - std::unique_lock free_lock(free_keys_mutex_); - std::unique_lock available_threads_lock(available_threads_mutex_); - typed_->reset(); - slot_lookup_.clear(); - vectors_lookup_.clear(); - free_keys_.clear(); - vectors_tape_allocator_.reset(); - - // Reset the thread IDs. - available_threads_.resize(std::thread::hardware_concurrency()); - std::iota(available_threads_.begin(), available_threads_.end(), 0ul); - } - - /** + */ + void reset() { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + std::unique_lock available_threads_lock(available_threads_mutex_); + typed_->reset(); + slot_lookup_.clear(); + vectors_lookup_.clear(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + + // Reset the thread IDs. + available_threads_.resize(std::thread::hardware_concurrency()); + std::iota(available_threads_.begin(), available_threads_.end(), 0ul); + } + + /** * @brief Saves serialized binary index representation to a stream. 
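Note: the functional change in the reserve() hunk above is that, besides growing slot_lookup_ and vectors_lookup_, reserve() now rebuilds available_threads_ from limits.threads(), so the number of usable insertion contexts is decided at reserve() time rather than fixed at make() time. A hedged sketch of reserving for a known writer count before parallel insertion, assuming the two-argument index_limits_t constructor and the index_dense_t alias from upstream USearch:

    #include <thread>
    #include <vector>
    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    // `index` is assumed to be a freshly made f32 index (see the earlier sketch).
    void parallel_fill(index_dense_t& index, std::size_t count) {
        std::size_t threads = std::thread::hardware_concurrency();

        // One call reserves member capacity *and* rebuilds the per-thread contexts.
        index.reserve(index_limits_t(count, threads)); // assumed (members, threads) constructor

        std::vector<std::thread> workers;
        for (std::size_t t = 0; t != threads; ++t)
            workers.emplace_back([&index, count, threads, t] {
                std::vector<float> vector(index.dimensions(), float(t));
                for (std::size_t key = t; key < count; key += threads)
                    index.add(key, vector.data(), /*thread=*/t); // pin each worker to one context
            });
        for (auto& worker : workers)
            worker.join();
    }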
- */ - template - serialization_result_t save_to_stream(output_callback_at&& output, // - serialization_config_t config = {}, // - progress_at&& progress = {}) const { - - serialization_result_t result; - std::uint64_t matrix_rows = 0; - std::uint64_t matrix_cols = 0; - - // We may not want to put the vectors into the same file - if (!config.exclude_vectors) { - // Save the matrix size - if (!config.use_64_bit_dimensions) { - std::uint32_t dimensions[2]; - dimensions[0] = static_cast(typed_->size()); - dimensions[1] = static_cast(metric_.bytes_per_vector()); - if (!output(&dimensions, sizeof(dimensions))) - return result.failed("Failed to serialize into stream"); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - } else { - std::uint64_t dimensions[2]; - dimensions[0] = static_cast(typed_->size()); - dimensions[1] = static_cast(metric_.bytes_per_vector()); - if (!output(&dimensions, sizeof(dimensions))) - return result.failed("Failed to serialize into stream"); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - } - - // Dump the vectors one after another - for (std::uint64_t i = 0; i != matrix_rows; ++i) { - byte_t* vector = vectors_lookup_[i]; - if (!output(vector, matrix_cols)) - return result.failed("Failed to serialize into stream"); - } - } - - // Augment metadata - { - index_dense_head_buffer_t buffer; - std::memset(buffer, 0, sizeof(buffer)); - index_dense_head_t head{buffer}; - std::memcpy(buffer, default_magic(), std::strlen(default_magic())); - - // Describe software version - using version_t = index_dense_head_t::version_t; - head.version_major = static_cast(USEARCH_VERSION_MAJOR); - head.version_minor = static_cast(USEARCH_VERSION_MINOR); - head.version_patch = static_cast(USEARCH_VERSION_PATCH); - - // Describes types used - head.kind_metric = metric_.metric_kind(); - head.kind_scalar = metric_.scalar_kind(); - head.kind_key = unum::usearch::scalar_kind(); - head.kind_compressed_slot = unum::usearch::scalar_kind(); - - head.count_present = size(); - head.count_deleted = typed_->size() - size(); - head.dimensions = dimensions(); - head.multi = multi(); - - if (!output(&buffer, sizeof(buffer))) - return result.failed("Failed to serialize into stream"); - } - - // Save the actual proximity graph - return typed_->save_to_stream(std::forward(output), std::forward(progress)); - } - - /** + */ + template + serialization_result_t save_to_stream(output_callback_at&& output, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to put the vectors into the same file + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + + // Dump the vectors one after another + for (std::uint64_t i = 0; i != matrix_rows; ++i) { + byte_t* vector = vectors_lookup_[i]; + 
if (!output(vector, matrix_cols)) + return result.failed("Failed to serialize into stream"); + } + } + + // Augment metadata + { + index_dense_head_buffer_t buffer; + std::memset(buffer, 0, sizeof(buffer)); + index_dense_head_t head{buffer}; + std::memcpy(buffer, default_magic(), std::strlen(default_magic())); + + // Describe software version + using version_t = index_dense_head_t::version_t; + head.version_major = static_cast(USEARCH_VERSION_MAJOR); + head.version_minor = static_cast(USEARCH_VERSION_MINOR); + head.version_patch = static_cast(USEARCH_VERSION_PATCH); + + // Describes types used + head.kind_metric = metric_.metric_kind(); + head.kind_scalar = metric_.scalar_kind(); + head.kind_key = unum::usearch::scalar_kind(); + head.kind_compressed_slot = unum::usearch::scalar_kind(); + + head.count_present = size(); + head.count_deleted = typed_->size() - size(); + head.dimensions = dimensions(); + head.multi = multi(); + + if (!output(&buffer, sizeof(buffer))) + return result.failed("Failed to serialize into stream"); + } + + // Save the actual proximity graph + return typed_->save_to_stream(std::forward(output), std::forward(progress)); + } + + /** * @brief Estimate the binary length (in bytes) of the serialized index. - */ - std::size_t serialized_length(serialization_config_t config = {}) const noexcept { - std::size_t dimensions_length = 0; - std::size_t matrix_length = 0; - if (!config.exclude_vectors) { - dimensions_length = config.use_64_bit_dimensions ? sizeof(std::uint64_t) * 2 : sizeof(std::uint32_t) * 2; - matrix_length = typed_->size() * metric_.bytes_per_vector(); - } - return dimensions_length + matrix_length + sizeof(index_dense_head_buffer_t) + typed_->serialized_length(); - } - - /** + */ + std::size_t serialized_length(serialization_config_t config = {}) const noexcept { + std::size_t dimensions_length = 0; + std::size_t matrix_length = 0; + if (!config.exclude_vectors) { + dimensions_length = config.use_64_bit_dimensions ? sizeof(std::uint64_t) * 2 : sizeof(std::uint32_t) * 2; + matrix_length = typed_->size() * metric_.bytes_per_vector(); + } + return dimensions_length + matrix_length + sizeof(index_dense_head_buffer_t) + typed_->serialized_length(); + } + + /** * @brief Parses the index from file to RAM. * @param[in] path The path to the file. * @param[in] config Configuration parameters for imports. * @return Outcome descriptor explicitly convertible to boolean. 
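Note: save_to_stream() above writes, in order, the optional vector matrix (two 32- or 64-bit dimension integers followed by the row-major vector bytes), the fixed-size header buffer, and finally the proximity graph; serialized_length() mirrors exactly that layout. A hedged sketch that serializes into an in-memory buffer through the same output-callback shape used by save() later in this file (index_dense_t alias and the defaulted progress template parameter assumed):

    #include <vector>
    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    // `index` is assumed to be a populated index (see the earlier sketches).
    std::vector<char> serialize_to_memory(index_dense_t const& index) {
        std::vector<char> blob;
        blob.reserve(index.serialized_length()); // dimensions + vectors + header + graph

        auto result = index.save_to_stream([&](void const* buffer, std::size_t length) {
            char const* bytes = static_cast<char const*>(buffer);
            blob.insert(blob.end(), bytes, bytes + length);
            return true; // report success back to the serializer
        });
        return result ? blob : std::vector<char>{};
    }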
- */ - template - serialization_result_t load_from_stream(input_callback_at&& input, // - serialization_config_t config = {}, // - progress_at&& progress = {}) { - - // Discard all previous memory allocations of `vectors_tape_allocator_` - reset(); - - // Infer the new index size - serialization_result_t result; - std::uint64_t matrix_rows = 0; - std::uint64_t matrix_cols = 0; - - // We may not want to load the vectors from the same file, or allow attaching them afterwards - if (!config.exclude_vectors) { - // Save the matrix size - if (!config.use_64_bit_dimensions) { - std::uint32_t dimensions[2]; - if (!input(&dimensions, sizeof(dimensions))) - return result.failed("Failed to read 32-bit dimensions of the matrix"); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - } else { - std::uint64_t dimensions[2]; - if (!input(&dimensions, sizeof(dimensions))) - return result.failed("Failed to read 64-bit dimensions of the matrix"); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - } - // Load the vectors one after another - vectors_lookup_.resize(matrix_rows); - for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { - byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); - if (!input(vector, matrix_cols)) - return result.failed("Failed to read vectors"); - vectors_lookup_[slot] = vector; - } - } - - // Load metadata and choose the right metric - { - index_dense_head_buffer_t buffer; - if (!input(buffer, sizeof(buffer))) - return result.failed("Failed to read the index "); - - index_dense_head_t head{buffer}; - if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) - return result.failed("Magic header mismatch - the file isn't an index"); - - // Validate the software version - if (head.version_major != USEARCH_VERSION_MAJOR) - return result.failed("File format may be different, please rebuild"); - - // Check the types used - if (head.kind_key != unum::usearch::scalar_kind()) - return result.failed("Key type doesn't match, consider rebuilding"); - if (head.kind_compressed_slot != unum::usearch::scalar_kind()) - return result.failed("Slot type doesn't match, consider rebuilding"); - - config_.multi = head.multi; - metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); - cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); - casts_ = make_casts_(head.kind_scalar); - } - - // Pull the actual proximity graph - result = typed_->load_from_stream(std::forward(input), std::forward(progress)); - if (!result) - return result; - if (typed_->size() != static_cast(matrix_rows)) - return result.failed("Index size and the number of vectors doesn't match"); - - reindex_keys_(); - return result; - } - - /** + */ + template + serialization_result_t load_from_stream(input_callback_at&& input, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + // Infer the new index size + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to load the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 32-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t 
dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 64-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + // Load the vectors one after another + vectors_lookup_.resize(matrix_rows); + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { + byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); + if (!input(vector, matrix_cols)) + return result.failed("Failed to read vectors"); + vectors_lookup_[slot] = vector; + } + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (!input(buffer, sizeof(buffer))) + return result.failed("Failed to read the index "); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + } + + // Pull the actual proximity graph + result = typed_->load_from_stream(std::forward(input), std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + reindex_keys_(); + return result; + } + + /** * @brief Parses the index from file, without loading it into RAM. * @param[in] path The path to the file. * @param[in] config Configuration parameters for imports. * @return Outcome descriptor explicitly convertible to boolean. 
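Note: the other functional change above is in how load_from_stream() re-creates the metric from the file header: it now calls metric_t::builtin(dimensions, kind_metric, kind_scalar) instead of the metric_punned_t constructor, then resizes cast_buffer_ and rebuilds the casts. After a successful load, the geometry therefore comes from the header, which can be inspected as in this hedged sketch (index_dense_t alias and the path-based load() overload defined later in this file assumed):

    #include <cstdio>
    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    // Loads a previously saved dense index into RAM; the path is hypothetical.
    bool load_and_report(index_dense_t& index, char const* path) {
        auto result = index.load(path); // char const* overload defined later in this file
        if (!result)
            return false;
        // Dimensions, scalar kind, and the multi-flag are taken from the file header on load.
        std::printf("dims=%zu scalar_kind=%u multi=%d\n", index.dimensions(),
                    static_cast<unsigned>(index.scalar_kind()), int(index.multi()));
        return true;
    }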
- */ - template - serialization_result_t view(memory_mapped_file_t file, // - std::size_t offset = 0, serialization_config_t config = {}, // - progress_at&& progress = {}) { - - // Discard all previous memory allocations of `vectors_tape_allocator_` - reset(); - - serialization_result_t result = file.open_if_not(); - if (!result) - return result; - - // Infer the new index size - std::uint64_t matrix_rows = 0; - std::uint64_t matrix_cols = 0; - span_punned_t vectors_buffer; - - // We may not want to fetch the vectors from the same file, or allow attaching them afterwards - if (!config.exclude_vectors) { - // Save the matrix size - if (!config.use_64_bit_dimensions) { - std::uint32_t dimensions[2]; - if (file.size() - offset < sizeof(dimensions)) - return result.failed("File is corrupted and lacks matrix dimensions"); - std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - offset += sizeof(dimensions); - } else { - std::uint64_t dimensions[2]; - if (file.size() - offset < sizeof(dimensions)) - return result.failed("File is corrupted and lacks matrix dimensions"); - std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); - matrix_rows = dimensions[0]; - matrix_cols = dimensions[1]; - offset += sizeof(dimensions); - } - vectors_buffer = {file.data() + offset, static_cast(matrix_rows * matrix_cols)}; - offset += vectors_buffer.size(); - } - - // Load metadata and choose the right metric - { - index_dense_head_buffer_t buffer; - if (file.size() - offset < sizeof(buffer)) - return result.failed("File is corrupted and lacks a header"); - - std::memcpy(buffer, file.data() + offset, sizeof(buffer)); - - index_dense_head_t head{buffer}; - if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) - return result.failed("Magic header mismatch - the file isn't an index"); - - // Validate the software version - if (head.version_major != USEARCH_VERSION_MAJOR) - return result.failed("File format may be different, please rebuild"); - - // Check the types used - if (head.kind_key != unum::usearch::scalar_kind()) - return result.failed("Key type doesn't match, consider rebuilding"); - if (head.kind_compressed_slot != unum::usearch::scalar_kind()) - return result.failed("Slot type doesn't match, consider rebuilding"); - - config_.multi = head.multi; - metric_ = metric_t(head.dimensions, head.kind_metric, head.kind_scalar); - cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); - casts_ = make_casts_(head.kind_scalar); - offset += sizeof(buffer); - } - - // Pull the actual proximity graph - result = typed_->view(std::move(file), offset, std::forward(progress)); - if (!result) - return result; - if (typed_->size() != static_cast(matrix_rows)) - return result.failed("Index size and the number of vectors doesn't match"); - - // Address the vectors - vectors_lookup_.resize(matrix_rows); - if (!config.exclude_vectors) - for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) - vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; - - reindex_keys_(); - return result; - } - - /** + */ + template + serialization_result_t view(memory_mapped_file_t file, // + std::size_t offset = 0, serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Infer the new index size + 
std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + span_punned_t vectors_buffer; + + // We may not want to fetch the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } else { + std::uint64_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } + vectors_buffer = {file.data() + offset, static_cast(matrix_rows * matrix_cols)}; + offset += vectors_buffer.size(); + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (file.size() - offset < sizeof(buffer)) + return result.failed("File is corrupted and lacks a header"); + + std::memcpy(buffer, file.data() + offset, sizeof(buffer)); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + cast_buffer_.resize(available_threads_.size() * metric_.bytes_per_vector()); + casts_ = make_casts_(head.kind_scalar); + offset += sizeof(buffer); + } + + // Pull the actual proximity graph + result = typed_->view(std::move(file), offset, std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + + // Address the vectors + vectors_lookup_.resize(matrix_rows); + if (!config.exclude_vectors) + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) + vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; + + reindex_keys_(); + return result; + } + + /** * @brief Saves the index to a file. * @param[in] path The path to the file. * @param[in] config Configuration parameters for exports. * @return Outcome descriptor explicitly convertible to boolean. 
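Note: view() in this hunk differs from load_from_stream() in that it never copies vectors onto vectors_tape_allocator_: each vectors_lookup_ slot is pointed straight into the mapped matrix, so the OS pages vectors in on demand. A hedged sketch using the path-based make() overload shown earlier in this file (index_dense_t alias assumed):

    #include <usearch/index_dense.hpp> // assumed include path

    using namespace unum::usearch;

    int main() {
        // Hypothetical path to a previously saved index that still contains its vectors.
        // `view = true` memory-maps the file instead of copying it into RAM.
        index_dense_t viewer = index_dense_t::make("index.usearch", /*view=*/true);
        if (!viewer)
            return 1;
        // Vectors are read straight out of the mapping; a viewed index is
        // typically treated as read-only.
        return viewer.size() > 0 ? 0 : 1;
    }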
- */ - template - serialization_result_t save(output_file_t file, serialization_config_t config = {}, - progress_at&& progress = {}) const { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void const* buffer, std::size_t length) { - io_result = file.write(buffer, length); - return !!io_result; - }, - config, std::forward(progress)); - - if (!stream_result) { - io_result.error.release(); - return stream_result; - } - return io_result; - } - - /** + */ + template + serialization_result_t save(output_file_t file, serialization_config_t config = {}, + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** * @brief Memory-maps the serialized binary index representation from disk, * @b without copying data into RAM, and fetching it on-demand. - */ - template - serialization_result_t save(memory_mapped_file_t file, // - std::size_t offset = 0, // - serialization_config_t config = {}, // - progress_at&& progress = {}) const { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = save_to_stream( - [&](void const* buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(file.data() + offset, buffer, length); - offset += length; - return true; - }, - config, std::forward(progress)); - - return stream_result; - } - - /** + */ + template + serialization_result_t save(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + /** * @brief Parses the index from file to RAM. * @param[in] path The path to the file. * @param[in] config Configuration parameters for imports. * @return Outcome descriptor explicitly convertible to boolean. 
- */ - template - serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void* buffer, std::size_t length) { - io_result = file.read(buffer, length); - return !!io_result; - }, - config, std::forward(progress)); - - if (!stream_result) { - io_result.error.release(); - return stream_result; - } - return io_result; - } - - /** + */ + template + serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** * @brief Memory-maps the serialized binary index representation from disk, * @b without copying data into RAM, and fetching it on-demand. - */ - template - serialization_result_t load(memory_mapped_file_t file, // - std::size_t offset = 0, // - serialization_config_t config = {}, // - progress_at&& progress = {}) { - - serialization_result_t io_result = file.open_if_not(); - if (!io_result) - return io_result; - - serialization_result_t stream_result = load_from_stream( - [&](void* buffer, std::size_t length) { - if (offset + length > file.size()) - return false; - std::memcpy(buffer, file.data() + offset, length); - offset += length; - return true; - }, - config, std::forward(progress)); - - return stream_result; - } - - template - serialization_result_t save(char const* file_path, // - serialization_config_t config = {}, // - progress_at&& progress = {}) const { - return save(output_file_t(file_path), config, std::forward(progress)); - } - - template - serialization_result_t load(char const* file_path, // - serialization_config_t config = {}, // - progress_at&& progress = {}) { - return load(input_file_t(file_path), config, std::forward(progress)); - } - - /** + */ + template + serialization_result_t load(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + template + serialization_result_t save(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + return save(output_file_t(file_path), config, std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + return load(input_file_t(file_path), config, std::forward(progress)); + } + + /** * @brief Checks if a vector with specified key is present. * @return `true` if the key is present in the index, `false` otherwise. 
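Putting the path-based convenience overloads together, a round trip might look like the sketch below. The `index_dense_t` alias is defined at the end of this header; `index_dense_t::make` and `metric_punned_t` come from parts of these headers that are not shown in this hunk, so treat their exact signatures as assumptions.

    #include <usearch/index_dense.hpp>   // assumed include path for this header
    using namespace unum::usearch;

    int main() {
        metric_punned_t metric(3, metric_kind_t::cos_k, scalar_kind_t::f32_k);
        index_dense_t index = index_dense_t::make(metric);
        index.reserve(8);

        float vec[3] = {0.1f, 0.2f, 0.3f};
        index.add(42, &vec[0]);
        index.save("index.usearch");                  // output_file_t under the hood

        index_dense_t restored = index_dense_t::make(metric);
        restored.load("index.usearch");               // copies everything into RAM
        return restored.contains(42) ? 0 : 1;
    }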
- */ - bool contains(vector_key_t key) const { - shared_lock_t lock(slot_lookup_mutex_); - return slot_lookup_.contains(key_and_slot_t::any_slot(key)); - } + */ + bool contains(vector_key_t key) const { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.contains(key_and_slot_t::any_slot(key)); + } - /** + /** * @brief Count the number of vectors with specified key present. * @return Zero if nothing is found, a positive integer otherwise. - */ - std::size_t count(vector_key_t key) const { - shared_lock_t lock(slot_lookup_mutex_); - return slot_lookup_.count(key_and_slot_t::any_slot(key)); - } - - struct labeling_result_t { - error_t error{}; - std::size_t completed{}; - - explicit operator bool() const noexcept { return !error; } - labeling_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** + */ + std::size_t count(vector_key_t key) const { + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.count(key_and_slot_t::any_slot(key)); + } + + struct labeling_result_t { + error_t error{}; + std::size_t completed{}; + + explicit operator bool() const noexcept { return !error; } + labeling_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** * @brief Removes an entry with the specified key from the index. * @param[in] key The key of the entry to remove. * @return The ::labeling_result_t indicating the result of the removal operation. * If the removal was successful, `result.completed` will be `true`. * If the key was not found in the index, `result.completed` will be `false`. * If an error occurred during the removal operation, `result.error` will contain an error message. - */ - labeling_result_t remove(vector_key_t key) { - labeling_result_t result; - - unique_lock_t lookup_lock(slot_lookup_mutex_); - auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); - if (matching_slots.first == matching_slots.second) - return result; - - // Grow the removed entries ring, if needed - std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); - std::unique_lock free_lock(free_keys_mutex_); - if (!free_keys_.reserve(free_keys_.size() + matching_count)) - return result.failed("Can't allocate memory for a free-list"); - - // A removed entry would be: - // - present in `free_keys_` - // - missing in the `slot_lookup_` - // - marked in the `typed_` index with a `free_key_` - for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { - compressed_slot_t slot = (*slots_it).slot; - free_keys_.push(slot); - typed_->at(slot).key = free_key_; - } - slot_lookup_.erase(key); - result.completed = matching_count; - - return result; - } - - /** + */ + labeling_result_t remove(vector_key_t key) { + labeling_result_t result; + + unique_lock_t lookup_lock(slot_lookup_mutex_); + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + if (matching_slots.first == matching_slots.second) + return result; + + // Grow the removed entries ring, if needed + std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); + std::unique_lock free_lock(free_keys_mutex_); + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + 
for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + } + slot_lookup_.erase(key); + result.completed = matching_count; + + return result; + } + + /** * @brief Removes multiple entries with the specified keys from the index. * @param[in] keys_begin The beginning of the keys range. * @param[in] keys_end The ending of the keys range. * @return The ::labeling_result_t indicating the result of the removal operation. * `result.completed` will contain the number of keys that were successfully removed. * `result.error` will contain an error message if an error occurred during the removal operation. - */ - template - labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { - - labeling_result_t result; - unique_lock_t lookup_lock(slot_lookup_mutex_); - std::unique_lock free_lock(free_keys_mutex_); - // Grow the removed entries ring, if needed - std::size_t matching_count = 0; - for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) - matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it)); - - if (!free_keys_.reserve(free_keys_.size() + matching_count)) - return result.failed("Can't allocate memory for a free-list"); - - // Remove them one-by-one - for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) { - vector_key_t key = *keys_it; - auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); - // A removed entry would be: - // - present in `free_keys_` - // - missing in the `slot_lookup_` - // - marked in the `typed_` index with a `free_key_` - matching_count = 0; - for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { - compressed_slot_t slot = (*slots_it).slot; - free_keys_.push(slot); - typed_->at(slot).key = free_key_; - ++matching_count; - } - - slot_lookup_.erase(key); - result.completed += matching_count; - } - - return result; - } - - /** + */ + template + labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { + + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + std::unique_lock free_lock(free_keys_mutex_); + // Grow the removed entries ring, if needed + std::size_t matching_count = 0; + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) + matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it)); + + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // Remove them one-by-one + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) { + vector_key_t key = *keys_it; + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + matching_count = 0; + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + ++matching_count; + } + + slot_lookup_.erase(key); + result.completed += matching_count; + } + + return result; + } + + /** * @brief Renames an entry with the specified key to a new key. * @param[in] from The current key of the entry to rename. * @param[in] to The new key to assign to the entry. 
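A hedged sketch of how the two removal overloads above might be used; the index is assumed to be populated, and the container is only for illustration.

    auto single = index.remove(42);
    if (!single)
        std::puts("could not grow the free-list");
    else if (single.completed == 0)
        std::puts("key 42 was not present");

    std::vector<index_dense_t::vector_key_t> stale = {7, 8, 9};
    auto batch = index.remove(stale.begin(), stale.end());
    // batch.completed counts how many of the three keys were actually found and freed.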
* @return The ::labeling_result_t indicating the result of the rename operation. * If the rename was successful, `result.completed` will be `true`. * If the entry with the current key was not found, `result.completed` will be `false`. - */ - labeling_result_t rename(vector_key_t from, vector_key_t to) { - labeling_result_t result; - unique_lock_t lookup_lock(slot_lookup_mutex_); - - if (!multi() && slot_lookup_.contains(key_and_slot_t::any_slot(to))) - return result.failed("Renaming impossible, the key is already in use"); - - // The `from` may map to multiple entries - while (true) { - key_and_slot_t key_and_slot_removed; - if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) - break; - - key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; - slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail - typed_->at(key_and_slot_removed.slot).key = to; - ++result.completed; - } - - return result; - } - - /** + */ + labeling_result_t rename(vector_key_t from, vector_key_t to) { + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + + if (!multi() && slot_lookup_.contains(key_and_slot_t::any_slot(to))) + return result.failed("Renaming impossible, the key is already in use"); + + // The `from` may map to multiple entries + while (true) { + key_and_slot_t key_and_slot_removed; + if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) + break; + + key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; + slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail + typed_->at(key_and_slot_removed.slot).key = to; + ++result.completed; + } + + return result; + } + + /** * @brief Exports a range of keys for the vectors present in the index. * @param[out] keys Pointer to the array where the keys will be exported. * @param[in] offset The number of keys to skip. Useful for pagination. * @param[in] limit The maximum number of keys to export, that can fit in ::keys. - */ - void export_keys(vector_key_t* keys, std::size_t offset, std::size_t limit) const { - shared_lock_t lock(slot_lookup_mutex_); - offset = (std::min)(offset, slot_lookup_.size()); - slot_lookup_.for_each([&](key_and_slot_t const& key_and_slot) { - if (offset) - // Skip the first `offset` entries - --offset; - else if (limit) { - *keys = key_and_slot.key; - ++keys; - --limit; - } - }); - } - - struct copy_result_t { - index_dense_gt index; - error_t error; - - explicit operator bool() const noexcept { return !error; } - copy_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** + */ + void export_keys(vector_key_t* keys, std::size_t offset, std::size_t limit) const { + shared_lock_t lock(slot_lookup_mutex_); + offset = (std::min)(offset, slot_lookup_.size()); + slot_lookup_.for_each([&](key_and_slot_t const& key_and_slot) { + if (offset) + // Skip the first `offset` entries + --offset; + else if (limit) { + *keys = key_and_slot.key; + ++keys; + --limit; + } + }); + } + + struct copy_result_t { + index_dense_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + copy_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** * @brief Copies the ::index_dense_gt @b with all the data in it. * @param config The copy configuration (optional). * @return A copy of the ::index_dense_gt instance. 
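As a rough usage sketch of `rename` and the paginated `export_keys` (assuming the usual `size()` accessor from the non-diffed part of this class):

    // Rename key 42 to 1042; with multi() disabled this fails if 1042 is already in use.
    auto renamed = index.rename(42, 1042);

    // Page through all present keys, 128 at a time.
    std::vector<index_dense_t::vector_key_t> page(128);
    for (std::size_t offset = 0; offset < index.size(); offset += page.size())
        index.export_keys(page.data(), offset, page.size());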
- */ - copy_result_t copy(index_dense_copy_config_t config = {}) const { - copy_result_t result = fork(); - if (!result) - return result; - - auto typed_result = typed_->copy(config); - if (!typed_result) - return result.failed(std::move(typed_result.error)); - - // Export the free (removed) slot numbers - index_dense_gt& copy = result.index; - if (!copy.free_keys_.reserve(free_keys_.size())) - return result.failed(std::move(typed_result.error)); - for (std::size_t i = 0; i != free_keys_.size(); ++i) - copy.free_keys_.push(free_keys_[i]); - - // Allocate buffers and move the vectors themselves - if (!config.force_vector_copy && copy.config_.exclude_vectors) - copy.vectors_lookup_ = vectors_lookup_; - else { - copy.vectors_lookup_.resize(vectors_lookup_.size()); - for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) - copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); - if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.end(), nullptr)) - return result.failed("Out of memory!"); - for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) - std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); - } - - copy.slot_lookup_ = slot_lookup_; - *copy.typed_ = std::move(typed_result.index); - return result; - } - - /** + */ + copy_result_t copy(index_dense_copy_config_t config = {}) const { + copy_result_t result = fork(); + if (!result) + return result; + + auto typed_result = typed_->copy(config); + if (!typed_result) + return result.failed(std::move(typed_result.error)); + + // Export the free (removed) slot numbers + index_dense_gt& copy = result.index; + if (!copy.free_keys_.reserve(free_keys_.size())) + return result.failed(std::move(typed_result.error)); + for (std::size_t i = 0; i != free_keys_.size(); ++i) + copy.free_keys_.push(free_keys_[i]); + + // Allocate buffers and move the vectors themselves + if (!config.force_vector_copy && copy.config_.exclude_vectors) + copy.vectors_lookup_ = vectors_lookup_; + else { + copy.vectors_lookup_.resize(vectors_lookup_.size()); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); + if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.end(), nullptr)) + return result.failed("Out of memory!"); + for (std::size_t slot = 0; slot != vectors_lookup_.size(); ++slot) + std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); + } + + copy.slot_lookup_ = slot_lookup_; + *copy.typed_ = std::move(typed_result.index); + return result; + } + + /** * @brief Copies the ::index_dense_gt model @b without any data. * @return A similarly configured ::index_dense_gt instance. 
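A minimal sketch of taking a deep copy; `force_vector_copy` is the field the body above consults to decide whether the replica owns its vectors even when the source merely viewed them.

    index_dense_copy_config_t config;
    config.force_vector_copy = true;
    auto copied = index.copy(config);
    if (copied)
        index_dense_t replica = std::move(copied.index);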
- */ - copy_result_t fork() const { - copy_result_t result; - index_dense_gt& other = result.index; - - other.config_ = config_; - other.cast_buffer_ = cast_buffer_; - other.casts_ = casts_; - - other.metric_ = metric_; - other.available_threads_ = available_threads_; - other.free_key_ = free_key_; - - index_t* raw = index_allocator_t{}.allocate(1); - if (!raw) - return result.failed("Can't allocate the index"); - - new (raw) index_t(config()); - other.typed_ = raw; - return result; - } - - struct compaction_result_t { - error_t error{}; - std::size_t pruned_edges{}; - - explicit operator bool() const noexcept { return !error; } - compaction_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** + */ + copy_result_t fork() const { + copy_result_t result; + index_dense_gt& other = result.index; + + other.config_ = config_; + other.cast_buffer_ = cast_buffer_; + other.casts_ = casts_; + + other.metric_ = metric_; + other.available_threads_ = available_threads_; + other.free_key_ = free_key_; + + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Can't allocate the index"); + + new (raw) index_t(config()); + other.typed_ = raw; + return result; + } + + struct compaction_result_t { + error_t error{}; + std::size_t pruned_edges{}; + + explicit operator bool() const noexcept { return !error; } + compaction_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** * @brief Performs compaction on the index, pruning links to removed entries. * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. * @return The ::compaction_result_t indicating the result of the compaction operation. * `result.pruned_edges` will contain the number of edges that were removed. * `result.error` will contain an error message if an error occurred during the compaction operation. 
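And the corresponding `fork()` sketch: same configuration, metric, and casts, but none of the vectors or graph.

    auto forked = index.fork();
    if (forked)
        index_dense_t empty_twin = std::move(forked.index);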
- */ - template - compaction_result_t isolate(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { - compaction_result_t result; - std::atomic pruned_edges; - auto disallow = [&](member_cref_t const& member) noexcept { - bool freed = member.key == free_key_; - pruned_edges += freed; - return freed; - }; - typed_->isolate(disallow, std::forward(executor), std::forward(progress)); - result.pruned_edges = pruned_edges; - return result; - } - - class values_proxy_t { - index_dense_gt const* index_; - - public: - values_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} - byte_t const* operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } - byte_t const* operator[](member_citerator_t it) const noexcept { return index_->vectors_lookup_[get_slot(it)]; } - }; - - /** + */ + template + compaction_result_t isolate(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + std::atomic pruned_edges; + auto disallow = [&](member_cref_t const& member) noexcept { + bool freed = member.key == free_key_; + pruned_edges += freed; + return freed; + }; + typed_->isolate(disallow, std::forward(executor), std::forward(progress)); + result.pruned_edges = pruned_edges; + return result; + } + + class values_proxy_t { + index_dense_gt const* index_; + + public: + values_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + byte_t const* operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } + byte_t const* operator[](member_citerator_t it) const noexcept { return index_->vectors_lookup_[get_slot(it)]; } + }; + + /** * @brief Performs compaction on the index, pruning links to removed entries. * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. * @return The ::compaction_result_t indicating the result of the compaction operation. * `result.pruned_edges` will contain the number of edges that were removed. * `result.error` will contain an error message if an error occurred during the compaction operation. 
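A short sketch of pruning after removals; per the doc comment above, the defaulted executor and progress arguments make this single-threaded.

    index.remove(42);               // the node stays in the graph, marked with the free key
    auto pruned = index.isolate();  // drops links that still point at removed entries
    // pruned.pruned_edges reports how many such links were cut.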
- */ - template - compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { - compaction_result_t result; - - std::vector new_vectors_lookup(vectors_lookup_.size()); - vectors_tape_allocator_t new_vectors_allocator; - - auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { - byte_t* new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); - byte_t* old_vector = vectors_lookup_[old_slot]; - std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); - new_vectors_lookup[new_slot] = new_vector; - }; - typed_->compact(values_proxy_t{*this}, metric_proxy_t{*this}, track_slot_change, - std::forward(executor), std::forward(progress)); - vectors_lookup_ = std::move(new_vectors_lookup); - vectors_tape_allocator_ = std::move(new_vectors_allocator); - return result; - } - - template < // - typename man_to_woman_at = dummy_key_to_key_mapping_t, // - typename woman_to_man_at = dummy_key_to_key_mapping_t, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > - join_result_t join( // - index_dense_gt const& women, // - index_join_config_t config = {}, // - man_to_woman_at&& man_to_woman = man_to_woman_at{}, // - woman_to_man_at&& woman_to_man = woman_to_man_at{}, // - executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}) const { - - index_dense_gt const& men = *this; - return unum::usearch::join( // - *men.typed_, *women.typed_, // - values_proxy_t{men}, values_proxy_t{women}, // - metric_proxy_t{men}, metric_proxy_t{women}, // - config, // - std::forward(man_to_woman), // - std::forward(woman_to_man), // - std::forward(executor), // - std::forward(progress)); - } - - struct clustering_result_t { - error_t error{}; - std::size_t clusters{}; - std::size_t visited_members{}; - std::size_t computed_distances{}; - - explicit operator bool() const noexcept { return !error; } - clustering_result_t failed(error_t message) noexcept { - error = std::move(message); - return std::move(*this); - } - }; - - /** + */ + template + compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + + std::vector new_vectors_lookup(vectors_lookup_.size()); + vectors_tape_allocator_t new_vectors_allocator; + + auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { + byte_t* new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); + byte_t* old_vector = vectors_lookup_[old_slot]; + std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); + new_vectors_lookup[new_slot] = new_vector; + }; + typed_->compact(values_proxy_t{*this}, metric_proxy_t{*this}, track_slot_change, + std::forward(executor), std::forward(progress)); + vectors_lookup_ = std::move(new_vectors_lookup); + vectors_tape_allocator_ = std::move(new_vectors_allocator); + return result; + } + + template < // + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + join_result_t join( // + index_dense_gt const& women, // + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) const { + + 
index_dense_gt const& men = *this; + return unum::usearch::join( // + *men.typed_, *women.typed_, // + values_proxy_t{men}, values_proxy_t{women}, // + metric_proxy_t{men}, metric_proxy_t{women}, // + config, // + std::forward(man_to_woman), // + std::forward(woman_to_man), // + std::forward(executor), // + std::forward(progress)); + } + + struct clustering_result_t { + error_t error{}; + std::size_t clusters{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + clustering_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** * @brief Implements clustering, classifying the given objects (vectors of member keys) * into a given number of clusters. * @@ -1526,465 +1546,453 @@ class index_dense_gt { * * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. - */ - template < // - typename queries_iterator_at, // - typename executor_at = dummy_executor_t, // - typename progress_at = dummy_progress_t // - > - clustering_result_t cluster( // - queries_iterator_at queries_begin, // - queries_iterator_at queries_end, // - index_dense_clustering_config_t config, // - vector_key_t* cluster_keys, // - distance_t* cluster_distances, // - executor_at&& executor = executor_at{}, // - progress_at&& progress = progress_at{}) { - - std::size_t const queries_count = queries_end - queries_begin; - - // Find the first level (top -> down) that has enough nodes to exceed `config.min_clusters`. - std::size_t level = max_level(); - if (config.min_clusters) { - for (; level > 1; --level) { - if (stats(level).nodes > config.min_clusters) - break; - } - } else - level = 1, config.max_clusters = stats(1).nodes, config.min_clusters = 2; - - clustering_result_t result; - if (max_level() < 2) - return result.failed("Index too small to cluster!"); - - // A structure used to track the popularity of a specific cluster - struct cluster_t { - vector_key_t centroid; - vector_key_t merged_into; - std::size_t popularity; - byte_t* vector; - }; - - auto centroid_id = [](cluster_t const& a, cluster_t const& b) { return a.centroid < b.centroid; }; - auto higher_popularity = [](cluster_t const& a, cluster_t const& b) { return a.popularity > b.popularity; }; - - std::atomic visited_members(0); - std::atomic computed_distances(0); - std::atomic atomic_error{nullptr}; - - using dynamic_allocator_traits_t = std::allocator_traits; - using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; - buffer_gt clusters(queries_count); - if (!clusters) - return result.failed("Out of memory!"); - - map_to_clusters: - // Concurrently perform search until a certain depth - executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { - auto result = cluster(queries_begin[query_idx], level, thread_idx); - if (!result) { - atomic_error = result.error.release(); - return false; - } - - cluster_keys[query_idx] = result.cluster.member.key; - cluster_distances[query_idx] = result.cluster.distance; - - // Export in case we need to refine afterwards - clusters[query_idx].centroid = result.cluster.member.key; - clusters[query_idx].vector = vectors_lookup_[result.cluster.member.slot]; - clusters[query_idx].merged_into = free_key(); - clusters[query_idx].popularity = 1; - - visited_members += result.visited_members; - 
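A hedged sketch of the semantic join defined above: `men` and `women` are assumed to be two populated `index_dense_t` instances built with the same metric and dimensionality, and the defaulted key mappings and executor are used. Note the free-standing `unum::usearch::join` helper at the end of this header forwards to this member with the `man_to_woman` / `woman_to_man` arguments swapped.

    auto matches = men.join(women);      // index_join_config_t{} defaults
    if (!matches)
        std::puts("join failed");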
computed_distances += result.computed_distances; - return true; - }); - - if (atomic_error) - return result.failed(atomic_error.load()); - - // Now once we have identified the closest clusters, - // we can try reducing their quantity, refining - std::sort(clusters.begin(), clusters.end(), centroid_id); - - // Transform into run-length encoding, computing the number of unique clusters - std::size_t unique_clusters = 0; - { - std::size_t last_idx = 0; - for (std::size_t current_idx = 1; current_idx != clusters.size(); ++current_idx) { - if (clusters[last_idx].centroid == clusters[current_idx].centroid) { - clusters[last_idx].popularity++; - } else { - last_idx++; - clusters[last_idx] = clusters[current_idx]; - } - } - unique_clusters = last_idx + 1; - } - - // In some cases the queries may be co-located, all mapping into the same cluster on that - // level. In that case we refine the granularity and dive deeper into clusters: - if (unique_clusters < config.min_clusters && level > 1) { - level--; - goto map_to_clusters; - } - - std::sort(clusters.data(), clusters.data() + unique_clusters, higher_popularity); - - // If clusters are too numerous, merge the ones that are too close to each other. - std::size_t merge_cycles = 0; - merge_nearby_clusters: - if (unique_clusters > config.max_clusters) { - - cluster_t& merge_source = clusters[unique_clusters - 1]; - std::size_t merge_target_idx = 0; - distance_t merge_distance = std::numeric_limits::max(); - - for (std::size_t candidate_idx = 0; candidate_idx + 1 < unique_clusters; ++candidate_idx) { - distance_t distance = metric_(merge_source.vector, clusters[candidate_idx].vector); - if (distance < merge_distance) { - merge_distance = distance; - merge_target_idx = candidate_idx; - } - } - - merge_source.merged_into = clusters[merge_target_idx].centroid; - clusters[merge_target_idx].popularity += exchange(merge_source.popularity, 0); - - // The target object may have to be swapped a few times to get to optimal position. 
- while (merge_target_idx && - clusters[merge_target_idx - 1].popularity < clusters[merge_target_idx].popularity) - std::swap(clusters[merge_target_idx - 1], clusters[merge_target_idx]), --merge_target_idx; - - unique_clusters--; - merge_cycles++; - goto merge_nearby_clusters; - } - - // Replace evicted clusters - if (merge_cycles) { - // Sort dropped clusters by name to accelerate future lookups - auto clusters_end = clusters.data() + config.max_clusters + merge_cycles; - std::sort(clusters.data(), clusters_end, centroid_id); - - executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { - vector_key_t& cluster_key = cluster_keys[query_idx]; - distance_t& cluster_distance = cluster_distances[query_idx]; - - // Recursively trace replacements of that cluster - while (true) { - // To avoid implementing heterogeneous comparisons, lets wrap the `cluster_key` - cluster_t updated_cluster; - updated_cluster.centroid = cluster_key; - updated_cluster = *std::lower_bound(clusters.data(), clusters_end, updated_cluster, centroid_id); - if (updated_cluster.merged_into == free_key()) - break; - cluster_key = updated_cluster.merged_into; - } - - cluster_distance = distance_between(cluster_key, queries_begin[query_idx], thread_idx).mean; - return true; - }); - } - - result.computed_distances = computed_distances; - result.visited_members = visited_members; - - (void)progress; - return result; - } - -private: - struct thread_lock_t { - index_dense_gt const& parent; - std::size_t thread_id; - bool engaged; - - ~thread_lock_t() { - if (engaged) - parent.thread_unlock_(thread_id); - } - }; - - thread_lock_t thread_lock_(std::size_t thread_id) const { - if (thread_id != any_thread()) - return {*this, thread_id, false}; - - available_threads_mutex_.lock(); - thread_id = available_threads_.back(); - available_threads_.pop_back(); - available_threads_mutex_.unlock(); - return {*this, thread_id, true}; - } - - void thread_unlock_(std::size_t thread_id) const { - available_threads_mutex_.lock(); - available_threads_.push_back(thread_id); - available_threads_mutex_.unlock(); - } - - template - add_result_t add_( // - vector_key_t key, scalar_at const* vector, // - std::size_t thread, bool force_vector_copy, cast_t const& cast) { - - if (!multi() && contains(key)) - return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); - - // Cast the vector, if needed for compatibility with `metric_` - thread_lock_t lock = thread_lock_(thread); - bool copy_vector = !config_.exclude_vectors || force_vector_copy; - byte_t const* vector_data = reinterpret_cast(vector); - { - byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; - bool casted = cast(vector_data, dimensions(), casted_data); - if (casted) - vector_data = casted_data, copy_vector = true; - } - - // Check if there are some removed entries, whose nodes we can reuse - compressed_slot_t free_slot = default_free_value(); - { - std::unique_lock lock(free_keys_mutex_); - free_keys_.try_pop(free_slot); - } - - // Perform the insertion or the update - bool reuse_node = free_slot != default_free_value(); - auto on_success = [&](member_ref_t member) { - unique_lock_t slot_lock(slot_lookup_mutex_); - slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); - if (copy_vector) { - if (!reuse_node) - vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); - std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); - } 
else - vectors_lookup_[member.slot] = (byte_t*)vector_data; - }; - - index_update_config_t update_config; - update_config.thread = lock.thread_id; - update_config.expansion = config_.expansion_add; - - metric_proxy_t metric{*this}; - return reuse_node // - ? typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) - : typed_->add(key, vector_data, metric, update_config, on_success); - } - - template - search_result_t search_( // - scalar_at const* vector, std::size_t wanted, // - std::size_t thread, bool exact, cast_t const& cast) const { - - // Cast the vector, if needed for compatibility with `metric_` - thread_lock_t lock = thread_lock_(thread); - byte_t const* vector_data = reinterpret_cast(vector); - { - byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; - bool casted = cast(vector_data, dimensions(), casted_data); - if (casted) - vector_data = casted_data; - } - - index_search_config_t search_config; - search_config.thread = lock.thread_id; - search_config.expansion = config_.expansion_search; - search_config.exact = exact; - - auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; - return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); - } - - template - search_result_t search_( // - scalar_at const* vector, std::size_t wanted, // - std::size_t thread, bool exact, cast_t const& cast, std::size_t ef_search) const { - - // Cast the vector, if needed for compatibility with `metric_` - thread_lock_t lock = thread_lock_(thread); - byte_t const* vector_data = reinterpret_cast(vector); - { - byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; - bool casted = cast(vector_data, dimensions(), casted_data); - if (casted) - vector_data = casted_data; - } - - index_search_config_t search_config; - search_config.thread = lock.thread_id; - search_config.expansion = ef_search; - search_config.exact = exact; - - auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; - return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); - } - - template - cluster_result_t cluster_( // - scalar_at const* vector, std::size_t level, // - std::size_t thread, cast_t const& cast) const { - - // Cast the vector, if needed for compatibility with `metric_` - thread_lock_t lock = thread_lock_(thread); - byte_t const* vector_data = reinterpret_cast(vector); - { - byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; - bool casted = cast(vector_data, dimensions(), casted_data); - if (casted) - vector_data = casted_data; - } - - index_cluster_config_t cluster_config; - cluster_config.thread = lock.thread_id; - cluster_config.expansion = config_.expansion_search; - - auto allow = [=](member_cref_t const& member) noexcept { return member.key != free_key_; }; - return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); - } - - template - aggregated_distances_t distance_between_( // - vector_key_t key, scalar_at const* vector, // - std::size_t thread, cast_t const& cast) const { - - // Cast the vector, if needed for compatibility with `metric_` - thread_lock_t lock = thread_lock_(thread); - byte_t const* vector_data = reinterpret_cast(vector); - { - byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; - bool casted = cast(vector_data, dimensions(), casted_data); - if 
(casted) - vector_data = casted_data; - } - - // Check if such `key` is even present. - shared_lock_t slots_lock(slot_lookup_mutex_); - auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); - aggregated_distances_t result; - if (key_range.first == key_range.second) - return result; - - result.min = std::numeric_limits::max(); - result.max = std::numeric_limits::min(); - result.mean = 0; - result.count = 0; - - while (key_range.first != key_range.second) { - key_and_slot_t key_and_slot = *key_range.first; - byte_t const* a_vector = vectors_lookup_[key_and_slot.slot]; - byte_t const* b_vector = vector_data; - distance_t a_b_distance = metric_(a_vector, b_vector); - - result.mean += a_b_distance; - result.min = (std::min)(result.min, a_b_distance); - result.max = (std::max)(result.max, a_b_distance); - result.count++; - - // - ++key_range.first; - } - - result.mean /= result.count; - return result; - } - - void reindex_keys_() { - - // Estimate number of entries first - std::size_t count_total = typed_->size(); - std::size_t count_removed = 0; - for (std::size_t i = 0; i != count_total; ++i) { - member_cref_t member = typed_->at(i); - count_removed += member.key == free_key_; - } - - if (!count_removed && !config_.enable_key_lookups) - return; - - // Pull entries from the underlying `typed_` into either - // into `slot_lookup_`, or `free_keys_` if they are unused. - unique_lock_t lock(slot_lookup_mutex_); - slot_lookup_.clear(); - if (config_.enable_key_lookups) - slot_lookup_.reserve(count_total - count_removed); - free_keys_.clear(); - free_keys_.reserve(count_removed); - for (std::size_t i = 0; i != typed_->size(); ++i) { - member_cref_t member = typed_->at(i); - if (member.key == free_key_) - free_keys_.push(static_cast(i)); - else if (config_.enable_key_lookups) - slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); - } - } - - template - std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const { - - if (!multi()) { - compressed_slot_t slot; - // Find the matching ID - { - shared_lock_t lock(slot_lookup_mutex_); - auto it = slot_lookup_.find(key_and_slot_t::any_slot(key)); - if (it == slot_lookup_.end()) - return false; - slot = (*it).slot; - } - // Export the entry - byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); - bool casted = cast(punned_vector, dimensions(), (byte_t*)reconstructed); - if (!casted) - std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); - return true; - } else { - shared_lock_t lock(slot_lookup_mutex_); - auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); - std::size_t count_exported = 0; - for (auto begin = equal_range_pair.first; - begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) { - // - compressed_slot_t slot = (*begin).slot; - byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); - byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported; - bool casted = cast(punned_vector, dimensions(), reconstructed_vector); - if (!casted) - std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); - } - return count_exported; - } - } - - template static casts_t make_casts_() { - casts_t result; - - result.from_b1x8 = cast_gt{}; - result.from_i8 = cast_gt{}; - result.from_f16 = cast_gt{}; - result.from_f32 = cast_gt{}; - result.from_f64 = cast_gt{}; - - result.to_b1x8 = 
cast_gt{}; - result.to_i8 = cast_gt{}; - result.to_f16 = cast_gt{}; - result.to_f32 = cast_gt{}; - result.to_f64 = cast_gt{}; - - return result; - } - - static casts_t make_casts_(scalar_kind_t scalar_kind) { - switch (scalar_kind) { - case scalar_kind_t::f64_k: return make_casts_(); - case scalar_kind_t::f32_k: return make_casts_(); - case scalar_kind_t::f16_k: return make_casts_(); - case scalar_kind_t::i8_k: return make_casts_(); - case scalar_kind_t::b1x8_k: return make_casts_(); - default: return {}; - } - } + */ + template < // + typename queries_iterator_at, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + clustering_result_t cluster( // + queries_iterator_at queries_begin, // + queries_iterator_at queries_end, // + index_dense_clustering_config_t config, // + vector_key_t* cluster_keys, // + distance_t* cluster_distances, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) { + + std::size_t const queries_count = queries_end - queries_begin; + + // Find the first level (top -> down) that has enough nodes to exceed `config.min_clusters`. + std::size_t level = max_level(); + if (config.min_clusters) { + for (; level > 1; --level) { + if (stats(level).nodes > config.min_clusters) + break; + } + } else + level = 1, config.max_clusters = stats(1).nodes, config.min_clusters = 2; + + clustering_result_t result; + if (max_level() < 2) + return result.failed("Index too small to cluster!"); + + // A structure used to track the popularity of a specific cluster + struct cluster_t { + vector_key_t centroid; + vector_key_t merged_into; + std::size_t popularity; + byte_t* vector; + }; + + auto centroid_id = [](cluster_t const& a, cluster_t const& b) { return a.centroid < b.centroid; }; + auto higher_popularity = [](cluster_t const& a, cluster_t const& b) { return a.popularity > b.popularity; }; + + std::atomic visited_members(0); + std::atomic computed_distances(0); + std::atomic atomic_error{nullptr}; + + using dynamic_allocator_traits_t = std::allocator_traits; + using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt clusters(queries_count); + if (!clusters) + return result.failed("Out of memory!"); + + map_to_clusters: + // Concurrently perform search until a certain depth + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + auto result = cluster(queries_begin[query_idx], level, thread_idx); + if (!result) { + atomic_error = result.error.release(); + return false; + } + + cluster_keys[query_idx] = result.cluster.member.key; + cluster_distances[query_idx] = result.cluster.distance; + + // Export in case we need to refine afterwards + clusters[query_idx].centroid = result.cluster.member.key; + clusters[query_idx].vector = vectors_lookup_[result.cluster.member.slot]; + clusters[query_idx].merged_into = free_key(); + clusters[query_idx].popularity = 1; + + visited_members += result.visited_members; + computed_distances += result.computed_distances; + return true; + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Now once we have identified the closest clusters, + // we can try reducing their quantity, refining + std::sort(clusters.begin(), clusters.end(), centroid_id); + + // Transform into run-length encoding, computing the number of unique clusters + std::size_t unique_clusters = 0; + { + std::size_t last_idx = 0; + for (std::size_t current_idx = 1; current_idx != clusters.size(); ++current_idx) { + 
if (clusters[last_idx].centroid == clusters[current_idx].centroid) { + clusters[last_idx].popularity++; + } else { + last_idx++; + clusters[last_idx] = clusters[current_idx]; + } + } + unique_clusters = last_idx + 1; + } + + // In some cases the queries may be co-located, all mapping into the same cluster on that + // level. In that case we refine the granularity and dive deeper into clusters: + if (unique_clusters < config.min_clusters && level > 1) { + level--; + goto map_to_clusters; + } + + std::sort(clusters.data(), clusters.data() + unique_clusters, higher_popularity); + + // If clusters are too numerous, merge the ones that are too close to each other. + std::size_t merge_cycles = 0; + merge_nearby_clusters: + if (unique_clusters > config.max_clusters) { + + cluster_t& merge_source = clusters[unique_clusters - 1]; + std::size_t merge_target_idx = 0; + distance_t merge_distance = std::numeric_limits::max(); + + for (std::size_t candidate_idx = 0; candidate_idx + 1 < unique_clusters; ++candidate_idx) { + distance_t distance = metric_(merge_source.vector, clusters[candidate_idx].vector); + if (distance < merge_distance) { + merge_distance = distance; + merge_target_idx = candidate_idx; + } + } + + merge_source.merged_into = clusters[merge_target_idx].centroid; + clusters[merge_target_idx].popularity += exchange(merge_source.popularity, 0); + + // The target object may have to be swapped a few times to get to optimal position. + while (merge_target_idx && + clusters[merge_target_idx - 1].popularity < clusters[merge_target_idx].popularity) + std::swap(clusters[merge_target_idx - 1], clusters[merge_target_idx]), --merge_target_idx; + + unique_clusters--; + merge_cycles++; + goto merge_nearby_clusters; + } + + // Replace evicted clusters + if (merge_cycles) { + // Sort dropped clusters by name to accelerate future lookups + auto clusters_end = clusters.data() + config.max_clusters + merge_cycles; + std::sort(clusters.data(), clusters_end, centroid_id); + + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + vector_key_t& cluster_key = cluster_keys[query_idx]; + distance_t& cluster_distance = cluster_distances[query_idx]; + + // Recursively trace replacements of that cluster + while (true) { + // To avoid implementing heterogeneous comparisons, lets wrap the `cluster_key` + cluster_t updated_cluster; + updated_cluster.centroid = cluster_key; + updated_cluster = *std::lower_bound(clusters.data(), clusters_end, updated_cluster, centroid_id); + if (updated_cluster.merged_into == free_key()) + break; + cluster_key = updated_cluster.merged_into; + } + + cluster_distance = distance_between(cluster_key, queries_begin[query_idx], thread_idx).mean; + return true; + }); + } + + result.computed_distances = computed_distances; + result.visited_members = visited_members; + + (void)progress; + return result; + } + + private: + struct thread_lock_t { + index_dense_gt const& parent; + std::size_t thread_id; + bool engaged; + + ~thread_lock_t() { + if (engaged) + parent.thread_unlock_(thread_id); + } + }; + + thread_lock_t thread_lock_(std::size_t thread_id) const { + if (thread_id != any_thread()) + return {*this, thread_id, false}; + + available_threads_mutex_.lock(); + thread_id = available_threads_.back(); + available_threads_.pop_back(); + available_threads_mutex_.unlock(); + return {*this, thread_id, true}; + } + + void thread_unlock_(std::size_t thread_id) const { + available_threads_mutex_.lock(); + available_threads_.push_back(thread_id); + 
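A sketch of invoking the clustering entry point above, under a few assumptions: the queries are member keys (as the doc comment says), `min_clusters` / `max_clusters` are the config fields the body consults, and distances are exported as plain floats.

    std::vector<index_dense_t::vector_key_t> queries(index.size());
    index.export_keys(queries.data(), 0, queries.size());

    index_dense_clustering_config_t config;
    config.min_clusters = 50;
    config.max_clusters = 100;

    std::vector<index_dense_t::vector_key_t> centroids(queries.size());
    std::vector<float> distances(queries.size());     // assumed to match distance_t
    auto clustering = index.cluster(queries.begin(), queries.end(), config,
                                    centroids.data(), distances.data());
    // clustering.clusters, visited_members, and computed_distances summarize the run.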
available_threads_mutex_.unlock(); + } + + template + add_result_t add_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, bool force_vector_copy, cast_t const& cast) { + + if (!multi() && contains(key)) + return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + bool copy_vector = !config_.exclude_vectors || force_vector_copy; + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data, copy_vector = true; + } + + // Check if there are some removed entries, whose nodes we can reuse + compressed_slot_t free_slot = default_free_value(); + { + std::unique_lock lock(free_keys_mutex_); + free_keys_.try_pop(free_slot); + } + + // Perform the insertion or the update + bool reuse_node = free_slot != default_free_value(); + auto on_success = [&](member_ref_t member) { + unique_lock_t slot_lock(slot_lookup_mutex_); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + if (copy_vector) { + if (!reuse_node) + vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); + std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); + } else + vectors_lookup_[member.slot] = (byte_t*)vector_data; + }; + + index_update_config_t update_config; + update_config.thread = lock.thread_id; + update_config.expansion = config_.expansion_add; + + metric_proxy_t metric{*this}; + return reuse_node // + ? typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) + : typed_->add(key, vector_data, metric, update_config, on_success); + } + + template + search_result_t search_(scalar_at const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread, + bool exact, cast_t const& cast, std::size_t expansion_search) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_search_config_t search_config; + search_config.thread = lock.thread_id; + search_config.expansion = expansion_search; + search_config.exact = exact; + + auto &free_key_ = this->free_key_; + if (std::is_same::type, dummy_predicate_t>::value) { + auto allow = [&free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + } else { + auto allow = [&free_key_, &predicate](member_cref_t const& member) noexcept { + return member.key != free_key_ && predicate(member.key); + }; + return typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + } + } + + template + cluster_result_t cluster_( // + scalar_at const* vector, std::size_t level, // + std::size_t thread, cast_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + 
metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_cluster_config_t cluster_config; + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + + auto &free_key_ = this->free_key_; + auto allow = [&free_key_](member_cref_t const& member) noexcept { + return member.key != free_key_; + }; + return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); + } + + template + aggregated_distances_t distance_between_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, cast_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + // Check if such `key` is even present. + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + aggregated_distances_t result; + if (key_range.first == key_range.second) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* a_vector = vectors_lookup_[key_and_slot.slot]; + byte_t const* b_vector = vector_data; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++key_range.first; + } + + result.mean /= result.count; + return result; + } + + void reindex_keys_() { + + // Estimate number of entries first + std::size_t count_total = typed_->size(); + std::size_t count_removed = 0; + for (std::size_t i = 0; i != count_total; ++i) { + member_cref_t member = typed_->at(i); + count_removed += member.key == free_key_; + } + + if (!count_removed && !config_.enable_key_lookups) + return; + + // Pull entries from the underlying `typed_` into either + // into `slot_lookup_`, or `free_keys_` if they are unused. 
+ unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.clear(); + if (config_.enable_key_lookups) + slot_lookup_.reserve(count_total - count_removed); + free_keys_.clear(); + free_keys_.reserve(count_removed); + for (std::size_t i = 0; i != typed_->size(); ++i) { + member_cref_t member = typed_->at(i); + if (member.key == free_key_) + free_keys_.push(static_cast(i)); + else if (config_.enable_key_lookups) + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), static_cast(i)}); + } + } + + template + std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, cast_t const& cast) const { + + if (!multi()) { + compressed_slot_t slot; + // Find the matching ID + { + shared_lock_t lock(slot_lookup_mutex_); + auto it = slot_lookup_.find(key_and_slot_t::any_slot(key)); + if (it == slot_lookup_.end()) + return false; + slot = (*it).slot; + } + // Export the entry + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + bool casted = cast(punned_vector, dimensions(), (byte_t*)reconstructed); + if (!casted) + std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); + return true; + } else { + shared_lock_t lock(slot_lookup_mutex_); + auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + std::size_t count_exported = 0; + for (auto begin = equal_range_pair.first; + begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) { + // + compressed_slot_t slot = (*begin).slot; + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported; + bool casted = cast(punned_vector, dimensions(), reconstructed_vector); + if (!casted) + std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); + } + return count_exported; + } + } + + template static casts_t make_casts_() { + casts_t result; + + result.from_b1x8 = cast_gt{}; + result.from_i8 = cast_gt{}; + result.from_f16 = cast_gt{}; + result.from_f32 = cast_gt{}; + result.from_f64 = cast_gt{}; + + result.to_b1x8 = cast_gt{}; + result.to_i8 = cast_gt{}; + result.to_f16 = cast_gt{}; + result.to_f32 = cast_gt{}; + result.to_f64 = cast_gt{}; + + return result; + } + + static casts_t make_casts_(scalar_kind_t scalar_kind) { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return make_casts_(); + case scalar_kind_t::f32_k: return make_casts_(); + case scalar_kind_t::f16_k: return make_casts_(); + case scalar_kind_t::i8_k: return make_casts_(); + case scalar_kind_t::b1x8_k: return make_casts_(); + default: return {}; + } + } }; using index_dense_t = index_dense_gt<>; @@ -2022,13 +2030,13 @@ static join_result_t join( // executor_at&& executor = executor_at{}, // progress_at&& progress = progress_at{}) noexcept { - return men.join( // - women, config, // - std::forward(woman_to_man), // - std::forward(man_to_woman), // - std::forward(executor), // - std::forward(progress)); + return men.join( // + women, config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); } } // namespace usearch -} // namespace unum \ No newline at end of file +} // namespace unum diff --git a/src/include/usearch/index_plugins.hpp b/src/include/usearch/index_plugins.hpp index 57e79a8..1ae5fe9 100644 --- a/src/include/usearch/index_plugins.hpp +++ b/src/include/usearch/index_plugins.hpp @@ -45,27 +45,19 @@ #endif #if USEARCH_USE_SIMSIMD - // Propagate 
the `f16` settings +#if !defined(SIMSIMD_NATIVE_F16) #define SIMSIMD_NATIVE_F16 !USEARCH_USE_FP16LIB - -#if !defined(SIMSIMD_TARGET_X86_AVX512) && defined(USEARCH_DEFINED_LINUX) -#define SIMSIMD_TARGET_X86_AVX512 1 -#endif - -#if !defined(SIMSIMD_TARGET_ARM_SVE) && defined(USEARCH_DEFINED_LINUX) -#define SIMSIMD_TARGET_ARM_SVE 1 -#endif - -#if !defined(SIMSIMD_TARGET_X86_AVX2) && (defined(USEARCH_DEFINED_LINUX) || defined(USEARCH_DEFINED_APPLE)) -#define SIMSIMD_TARGET_X86_AVX2 1 -#endif - -#if !defined(SIMSIMD_TARGET_ARM_NEON) && (defined(USEARCH_DEFINED_LINUX) || defined(USEARCH_DEFINED_APPLE)) -#define SIMSIMD_TARGET_ARM_NEON 1 #endif - +#define SIMSIMD_DYNAMIC_DISPATCH 0 +// No problem, if some of the functions are unused or undefined +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma warning(push) +#pragma warning(disable : 4101) #include +#pragma warning(pop) +#pragma GCC diagnostic pop #endif namespace unum { @@ -75,14 +67,16 @@ using u40_t = uint40_t; enum b1x8_t : unsigned char {}; struct uuid_t { - std::uint8_t octets[16]; + std::uint8_t octets[16]; }; class f16_bits_t; class i8_converted_t; #if !USEARCH_USE_FP16LIB -#if defined(USEARCH_DEFINED_ARM) +#if USEARCH_USE_SIMSIMD +using f16_native_t = simsimd_f16_t; +#elif defined(USEARCH_DEFINED_ARM) using f16_native_t = __fp16; #else using f16_native_t = _Float16; @@ -107,84 +101,84 @@ using i16_t = std::int16_t; using i8_t = std::int8_t; enum class metric_kind_t : std::uint8_t { - unknown_k = 0, - // Classics: - ip_k = 'i', - cos_k = 'c', - l2sq_k = 'e', - - // Custom: - pearson_k = 'p', - haversine_k = 'h', - divergence_k = 'd', - - // Sets: - jaccard_k = 'j', - hamming_k = 'b', - tanimoto_k = 't', - sorensen_k = 's', + unknown_k = 0, + // Classics: + ip_k = 'i', + cos_k = 'c', + l2sq_k = 'e', + + // Custom: + pearson_k = 'p', + haversine_k = 'h', + divergence_k = 'd', + + // Sets: + jaccard_k = 'j', + hamming_k = 'b', + tanimoto_k = 't', + sorensen_k = 's', }; enum class scalar_kind_t : std::uint8_t { - unknown_k = 0, - // Custom: - b1x8_k, - u40_k, - uuid_k, - // Common: - f64_k, - f32_k, - f16_k, - f8_k, - // Common Integral: - u64_k, - u32_k, - u16_k, - u8_k, - i64_k, - i32_k, - i16_k, - i8_k, + unknown_k = 0, + // Custom: + b1x8_k = 1, + u40_k = 2, + uuid_k = 3, + // Common: + f64_k = 10, + f32_k = 11, + f16_k = 12, + f8_k = 13, + // Common Integral: + u64_k = 14, + u32_k = 15, + u16_k = 16, + u8_k = 17, + i64_k = 20, + i32_k = 21, + i16_k = 22, + i8_k = 23, }; enum class prefetching_kind_t { - none_k, - cpu_k, - io_uring_k, + none_k, + cpu_k, + io_uring_k, }; template scalar_kind_t scalar_kind() noexcept { - if (std::is_same()) - return scalar_kind_t::b1x8_k; - if (std::is_same()) - return scalar_kind_t::u40_k; - if (std::is_same()) - return scalar_kind_t::uuid_k; - if (std::is_same()) - return scalar_kind_t::f64_k; - if (std::is_same()) - return scalar_kind_t::f32_k; - if (std::is_same()) - return scalar_kind_t::f16_k; - if (std::is_same()) - return scalar_kind_t::i8_k; - if (std::is_same()) - return scalar_kind_t::u64_k; - if (std::is_same()) - return scalar_kind_t::u32_k; - if (std::is_same()) - return scalar_kind_t::u16_k; - if (std::is_same()) - return scalar_kind_t::u8_k; - if (std::is_same()) - return scalar_kind_t::i64_k; - if (std::is_same()) - return scalar_kind_t::i32_k; - if (std::is_same()) - return scalar_kind_t::i16_k; - if (std::is_same()) - return scalar_kind_t::i8_k; - return scalar_kind_t::unknown_k; + if (std::is_same()) + return scalar_kind_t::b1x8_k; + if 
(std::is_same()) + return scalar_kind_t::u40_k; + if (std::is_same()) + return scalar_kind_t::uuid_k; + if (std::is_same()) + return scalar_kind_t::f64_k; + if (std::is_same()) + return scalar_kind_t::f32_k; + if (std::is_same()) + return scalar_kind_t::f16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + if (std::is_same()) + return scalar_kind_t::u64_k; + if (std::is_same()) + return scalar_kind_t::u32_k; + if (std::is_same()) + return scalar_kind_t::u16_k; + if (std::is_same()) + return scalar_kind_t::u8_k; + if (std::is_same()) + return scalar_kind_t::i64_k; + if (std::is_same()) + return scalar_kind_t::i32_k; + if (std::is_same()) + return scalar_kind_t::i16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + return scalar_kind_t::unknown_k; } template at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); } @@ -192,133 +186,133 @@ template at angle_to_radians(at angle) noexcept { return angle * a template at square(at value) noexcept { return value * value; } template inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { - return comp(v, lo) ? lo : comp(hi, v) ? hi : v; + return comp(v, lo) ? lo : comp(hi, v) ? hi : v; } template inline at clamp(at v, at lo, at hi) noexcept { - return usearch::clamp(v, lo, hi, std::less{}); + return usearch::clamp(v, lo, hi, std::less{}); } inline bool str_equals(char const* begin, std::size_t len, char const* other_begin) noexcept { - std::size_t other_len = std::strlen(other_begin); - return len == other_len && std::strncmp(begin, other_begin, len) == 0; + std::size_t other_len = std::strlen(other_begin); + return len == other_len && std::strncmp(begin, other_begin, len) == 0; } inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept { - switch (scalar_kind) { - case scalar_kind_t::f64_k: return 64; - case scalar_kind_t::f32_k: return 32; - case scalar_kind_t::f16_k: return 16; - case scalar_kind_t::i8_k: return 8; - case scalar_kind_t::b1x8_k: return 1; - default: return 0; - } + switch (scalar_kind) { + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 1; + default: return 0; + } } inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept { - switch (scalar_kind) { - case scalar_kind_t::f64_k: return 64; - case scalar_kind_t::f32_k: return 32; - case scalar_kind_t::f16_k: return 16; - case scalar_kind_t::i8_k: return 8; - case scalar_kind_t::b1x8_k: return 8; - default: return 0; - } + switch (scalar_kind) { + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::b1x8_k: return 8; + default: return 0; + } } inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept { - switch (scalar_kind) { - case scalar_kind_t::f32_k: return "f32"; - case scalar_kind_t::f16_k: return "f16"; - case scalar_kind_t::f64_k: return "f64"; - case scalar_kind_t::i8_k: return "i8"; - case scalar_kind_t::b1x8_k: return "b1x8"; - default: return ""; - } + switch (scalar_kind) { + case scalar_kind_t::f32_k: return "f32"; + case scalar_kind_t::f16_k: return "f16"; + case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::i8_k: return "i8"; + case scalar_kind_t::b1x8_k: return "b1x8"; + default: return ""; + } } inline char const* metric_kind_name(metric_kind_t metric) noexcept { - switch (metric) { 
- case metric_kind_t::unknown_k: return "unknown"; - case metric_kind_t::ip_k: return "ip"; - case metric_kind_t::cos_k: return "cos"; - case metric_kind_t::l2sq_k: return "l2sq"; - case metric_kind_t::pearson_k: return "pearson"; - case metric_kind_t::haversine_k: return "haversine"; - case metric_kind_t::divergence_k: return "divergence"; - case metric_kind_t::jaccard_k: return "jaccard"; - case metric_kind_t::hamming_k: return "hamming"; - case metric_kind_t::tanimoto_k: return "tanimoto"; - case metric_kind_t::sorensen_k: return "sorensen"; - } - return ""; + switch (metric) { + case metric_kind_t::unknown_k: return "unknown"; + case metric_kind_t::ip_k: return "ip"; + case metric_kind_t::cos_k: return "cos"; + case metric_kind_t::l2sq_k: return "l2sq"; + case metric_kind_t::pearson_k: return "pearson"; + case metric_kind_t::haversine_k: return "haversine"; + case metric_kind_t::divergence_k: return "divergence"; + case metric_kind_t::jaccard_k: return "jaccard"; + case metric_kind_t::hamming_k: return "hamming"; + case metric_kind_t::tanimoto_k: return "tanimoto"; + case metric_kind_t::sorensen_k: return "sorensen"; + } + return ""; } inline expected_gt scalar_kind_from_name(char const* name, std::size_t len) { - expected_gt parsed; - if (str_equals(name, len, "f32")) - parsed.result = scalar_kind_t::f32_k; - else if (str_equals(name, len, "f64")) - parsed.result = scalar_kind_t::f64_k; - else if (str_equals(name, len, "f16")) - parsed.result = scalar_kind_t::f16_k; - else if (str_equals(name, len, "i8")) - parsed.result = scalar_kind_t::i8_k; - else - parsed.failed("Unknown type, choose: f32, f16, f64, i8"); - return parsed; + expected_gt parsed; + if (str_equals(name, len, "f32")) + parsed.result = scalar_kind_t::f32_k; + else if (str_equals(name, len, "f64")) + parsed.result = scalar_kind_t::f64_k; + else if (str_equals(name, len, "f16")) + parsed.result = scalar_kind_t::f16_k; + else if (str_equals(name, len, "i8")) + parsed.result = scalar_kind_t::i8_k; + else + parsed.failed("Unknown type, choose: f32, f16, f64, i8"); + return parsed; } inline expected_gt scalar_kind_from_name(char const* name) { - return scalar_kind_from_name(name, std::strlen(name)); + return scalar_kind_from_name(name, std::strlen(name)); } inline expected_gt metric_from_name(char const* name, std::size_t len) { - expected_gt parsed; - if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) { - parsed.result = metric_kind_t::l2sq_k; - } else if (str_equals(name, len, "ip") || str_equals(name, len, "inner") || str_equals(name, len, "dot")) { - parsed.result = metric_kind_t::ip_k; - } else if (str_equals(name, len, "cos") || str_equals(name, len, "angular")) { - parsed.result = metric_kind_t::cos_k; - } else if (str_equals(name, len, "haversine")) { - parsed.result = metric_kind_t::haversine_k; - } else if (str_equals(name, len, "divergence")) { - parsed.result = metric_kind_t::divergence_k; - } else if (str_equals(name, len, "pearson")) { - parsed.result = metric_kind_t::pearson_k; - } else if (str_equals(name, len, "hamming")) { - parsed.result = metric_kind_t::hamming_k; - } else if (str_equals(name, len, "tanimoto")) { - parsed.result = metric_kind_t::tanimoto_k; - } else if (str_equals(name, len, "sorensen")) { - parsed.result = metric_kind_t::sorensen_k; - } else - parsed.failed("Unknown distance, choose: l2sq, ip, cos, haversine, divergence, jaccard, pearson, hamming, " - "tanimoto, sorensen"); - return parsed; + expected_gt parsed; + if (str_equals(name, len, "l2sq") || 
str_equals(name, len, "euclidean_sq")) { + parsed.result = metric_kind_t::l2sq_k; + } else if (str_equals(name, len, "ip") || str_equals(name, len, "inner") || str_equals(name, len, "dot")) { + parsed.result = metric_kind_t::ip_k; + } else if (str_equals(name, len, "cos") || str_equals(name, len, "angular")) { + parsed.result = metric_kind_t::cos_k; + } else if (str_equals(name, len, "haversine")) { + parsed.result = metric_kind_t::haversine_k; + } else if (str_equals(name, len, "divergence")) { + parsed.result = metric_kind_t::divergence_k; + } else if (str_equals(name, len, "pearson")) { + parsed.result = metric_kind_t::pearson_k; + } else if (str_equals(name, len, "hamming")) { + parsed.result = metric_kind_t::hamming_k; + } else if (str_equals(name, len, "tanimoto")) { + parsed.result = metric_kind_t::tanimoto_k; + } else if (str_equals(name, len, "sorensen")) { + parsed.result = metric_kind_t::sorensen_k; + } else + parsed.failed("Unknown distance, choose: l2sq, ip, cos, haversine, divergence, jaccard, pearson, hamming, " + "tanimoto, sorensen"); + return parsed; } inline expected_gt metric_from_name(char const* name) { - return metric_from_name(name, std::strlen(name)); + return metric_from_name(name, std::strlen(name)); } inline float f16_to_f32(std::uint16_t u16) noexcept { #if !USEARCH_USE_FP16LIB - f16_native_t f16; - std::memcpy(&f16, &u16, sizeof(std::uint16_t)); - return float(f16); + f16_native_t f16; + std::memcpy(&f16, &u16, sizeof(std::uint16_t)); + return float(f16); #else - return fp16_ieee_to_fp32_value(u16); + return fp16_ieee_to_fp32_value(u16); #endif } inline std::uint16_t f32_to_f16(float f32) noexcept { #if !USEARCH_USE_FP16LIB - f16_native_t f16 = f16_native_t(f32); - std::uint16_t u16; - std::memcpy(&u16, &f16, sizeof(std::uint16_t)); - return u16; + f16_native_t f16 = f16_native_t(f32); + std::uint16_t u16; + std::memcpy(&u16, &f16, sizeof(std::uint16_t)); + return u16; #else - return fp16_ieee_from_fp32_value(f32); + return fp16_ieee_from_fp32_value(f32); #endif } @@ -328,55 +322,55 @@ inline std::uint16_t f32_to_f16(float f32) noexcept { * agnostic in-software implementation. 
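For illustration only (a minimal sketch, not part of the patch; it assumes the usearch headers are on the include path and takes names from the `unum::usearch` namespace), the two helpers above round-trip a value through the 16-bit representation, whichever backend the preprocessor logic selected:

    // 0.3125f is exactly representable in half precision, so the round trip is lossless;
    // other values come back only to within half-precision rounding error.
    float original = 0.3125f;
    std::uint16_t bits = unum::usearch::f32_to_f16(original);
    float restored = unum::usearch::f16_to_f32(bits);   // == 0.3125f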
*/ class f16_bits_t { - std::uint16_t uint16_{}; - -public: - inline f16_bits_t() noexcept : uint16_(0) {} - inline f16_bits_t(f16_bits_t&&) = default; - inline f16_bits_t& operator=(f16_bits_t&&) = default; - inline f16_bits_t(f16_bits_t const&) = default; - inline f16_bits_t& operator=(f16_bits_t const&) = default; - - inline operator float() const noexcept { return f16_to_f32(uint16_); } - inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } - - inline f16_bits_t(i8_converted_t) noexcept; - inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) {} - inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} - inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(v)) {} - - inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } - inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } - inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } - inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } - inline f16_bits_t operator+(float other) const noexcept { return {float(*this) + other}; } - inline f16_bits_t operator-(float other) const noexcept { return {float(*this) - other}; } - inline f16_bits_t operator*(float other) const noexcept { return {float(*this) * other}; } - inline f16_bits_t operator/(float other) const noexcept { return {float(*this) / other}; } - inline f16_bits_t operator+(double other) const noexcept { return {float(*this) + other}; } - inline f16_bits_t operator-(double other) const noexcept { return {float(*this) - other}; } - inline f16_bits_t operator*(double other) const noexcept { return {float(*this) * other}; } - inline f16_bits_t operator/(double other) const noexcept { return {float(*this) / other}; } - - inline f16_bits_t& operator+=(float v) noexcept { - uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); - return *this; - } - - inline f16_bits_t& operator-=(float v) noexcept { - uint16_ = f32_to_f16(v - f16_to_f32(uint16_)); - return *this; - } - - inline f16_bits_t& operator*=(float v) noexcept { - uint16_ = f32_to_f16(v * f16_to_f32(uint16_)); - return *this; - } - - inline f16_bits_t& operator/=(float v) noexcept { - uint16_ = f32_to_f16(v / f16_to_f32(uint16_)); - return *this; - } + std::uint16_t uint16_{}; + + public: + inline f16_bits_t() noexcept : uint16_(0) {} + inline f16_bits_t(f16_bits_t&&) = default; + inline f16_bits_t& operator=(f16_bits_t&&) = default; + inline f16_bits_t(f16_bits_t const&) = default; + inline f16_bits_t& operator=(f16_bits_t const&) = default; + + inline operator float() const noexcept { return f16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } + + inline f16_bits_t(i8_converted_t) noexcept; + inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + + inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline f16_bits_t 
operator+(float other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(float other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(float other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(float other) const noexcept { return {float(*this) / other}; } + inline f16_bits_t operator+(double other) const noexcept { return {float(*this) + other}; } + inline f16_bits_t operator-(double other) const noexcept { return {float(*this) - other}; } + inline f16_bits_t operator*(double other) const noexcept { return {float(*this) * other}; } + inline f16_bits_t operator/(double other) const noexcept { return {float(*this) / other}; } + + inline f16_bits_t& operator+=(float v) noexcept { + uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator-=(float v) noexcept { + uint16_ = f32_to_f16(v - f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator*=(float v) noexcept { + uint16_ = f32_to_f16(v * f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator/=(float v) noexcept { + uint16_ = f32_to_f16(v / f16_to_f32(uint16_)); + return *this; + } }; /** @@ -384,104 +378,104 @@ class f16_bits_t { * Isn't efficient for small batches, as it recreates the threads on every call. */ class executor_stl_t { - std::size_t threads_count_{}; + std::size_t threads_count_{}; - struct jthread_t { - std::thread native_; + struct jthread_t { + std::thread native_; - jthread_t() = default; - jthread_t(jthread_t&&) = default; - jthread_t(jthread_t const&) = delete; - template jthread_t(callable_at&& func) : native_([=]() { func(); }) {} + jthread_t() = default; + jthread_t(jthread_t&&) = default; + jthread_t(jthread_t const&) = delete; + template jthread_t(callable_at&& func) : native_([=]() { func(); }) {} - ~jthread_t() { - if (native_.joinable()) - native_.join(); - } - }; + ~jthread_t() { + if (native_.joinable()) + native_.join(); + } + }; -public: - /** + public: + /** * @param threads_count The number of threads to be used for parallel execution. - */ - executor_stl_t(std::size_t threads_count = 0) noexcept - : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) {} + */ + executor_stl_t(std::size_t threads_count = 0) noexcept + : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) {} - /** + /** * @return Maximum number of threads available to the executor. - */ - std::size_t size() const noexcept { return threads_count_; } + */ + std::size_t size() const noexcept { return threads_count_; } - /** + /** * @brief Executes a fixed number of tasks using the specified thread-aware function. * @param tasks The total number of tasks to be executed. * @param thread_aware_function The thread-aware function to be called for each thread index and task index. * @throws If an exception occurs during execution of the thread-aware function. 
- */ - template - void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { - std::vector threads_pool; - std::size_t tasks_per_thread = tasks; - std::size_t threads_count = (std::min)(threads_count_, tasks); - if (threads_count > 1) { - tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); - for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { - threads_pool.emplace_back([=]() { - for (std::size_t task_idx = thread_idx * tasks_per_thread; - task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread); ++task_idx) - thread_aware_function(thread_idx, task_idx); - }); - } - } - for (std::size_t task_idx = 0; task_idx < (std::min)(tasks, tasks_per_thread); ++task_idx) - thread_aware_function(0, task_idx); - } - - /** + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + std::vector threads_pool; + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + threads_pool.emplace_back([=]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread); ++task_idx) + thread_aware_function(thread_idx, task_idx); + }); + } + } + for (std::size_t task_idx = 0; task_idx < (std::min)(tasks, tasks_per_thread); ++task_idx) + thread_aware_function(0, task_idx); + } + + /** * @brief Executes limited number of tasks using the specified thread-aware function. * @param tasks The upper bound on the number of tasks. * @param thread_aware_function The thread-aware function to be called for each thread index and task index. * @throws If an exception occurs during execution of the thread-aware function. 
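As a usage illustration (a minimal sketch, not part of the patch; it assumes <vector> is available and names come from `unum::usearch`), `executor_stl_t::fixed` splits a fixed batch of tasks across threads and passes each callback its thread index alongside the task index:

    std::vector<double> results(1000);
    unum::usearch::executor_stl_t pool;   // 0 threads requested -> std::thread::hardware_concurrency()
    pool.fixed(results.size(), [&](std::size_t thread_idx, std::size_t task_idx) {
        (void)thread_idx;                 // the thread index is typically used to pick per-thread scratch space
        results[task_idx] = double(task_idx) * 0.5;
    });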
- */ - template - void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { - std::vector threads_pool; - std::size_t tasks_per_thread = tasks; - std::size_t threads_count = (std::min)(threads_count_, tasks); - std::atomic_bool stop{false}; - if (threads_count > 1) { - tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); - for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { - threads_pool.emplace_back([=, &stop]() { - for (std::size_t task_idx = thread_idx * tasks_per_thread; - task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread) && - !stop.load(std::memory_order_relaxed); - ++task_idx) - if (!thread_aware_function(thread_idx, task_idx)) - stop.store(true, std::memory_order_relaxed); - }); - } - } - for (std::size_t task_idx = 0; - task_idx < (std::min)(tasks, tasks_per_thread) && !stop.load(std::memory_order_relaxed); ++task_idx) - if (!thread_aware_function(0, task_idx)) - stop.store(true, std::memory_order_relaxed); - } - - /** + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + std::vector threads_pool; + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + std::atomic_bool stop{false}; + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + threads_pool.emplace_back([=, &stop]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread) && + !stop.load(std::memory_order_relaxed); + ++task_idx) + if (!thread_aware_function(thread_idx, task_idx)) + stop.store(true, std::memory_order_relaxed); + }); + } + } + for (std::size_t task_idx = 0; + task_idx < (std::min)(tasks, tasks_per_thread) && !stop.load(std::memory_order_relaxed); ++task_idx) + if (!thread_aware_function(0, task_idx)) + stop.store(true, std::memory_order_relaxed); + } + + /** * @brief Saturates every available thread with the given workload, until they finish. * @param thread_aware_function The thread-aware function to be called for each thread index. * @throws If an exception occurs during execution of the thread-aware function. - */ - template - void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { - if (threads_count_ == 1) - return thread_aware_function(0); - std::vector threads_pool; - for (std::size_t thread_idx = 1; thread_idx < threads_count_; ++thread_idx) - threads_pool.emplace_back([=]() { thread_aware_function(thread_idx); }); - thread_aware_function(0); - } + */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { + if (threads_count_ == 1) + return thread_aware_function(0); + std::vector threads_pool; + for (std::size_t thread_idx = 1; thread_idx < threads_count_; ++thread_idx) + threads_pool.emplace_back([=]() { thread_aware_function(thread_idx); }); + thread_aware_function(0); + } }; #if USEARCH_USE_OPENMP @@ -491,71 +485,71 @@ class executor_stl_t { * Is the preferred implementation, when available, and maximum performance is needed. */ class executor_openmp_t { -public: - /** + public: + /** * @param threads_count The number of threads to be used for parallel execution. - */ - executor_openmp_t(std::size_t threads_count = 0) noexcept { - omp_set_num_threads(threads_count ? 
threads_count : std::thread::hardware_concurrency()); - } + */ + executor_openmp_t(std::size_t threads_count = 0) noexcept { + omp_set_num_threads(static_cast(threads_count ? threads_count : std::thread::hardware_concurrency())); + } - /** + /** * @return Maximum number of threads available to the executor. - */ - std::size_t size() const noexcept { return omp_get_num_threads(); } + */ + std::size_t size() const noexcept { return omp_get_max_threads(); } - /** + /** * @brief Executes tasks in bulk using the specified thread-aware function. * @param tasks The total number of tasks to be executed. * @param thread_aware_function The thread-aware function to be called for each thread index and task index. * @throws If an exception occurs during execution of the thread-aware function. - */ - template - void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { #pragma omp parallel for schedule(dynamic, 1) - for (std::size_t i = 0; i != tasks; ++i) { - thread_aware_function(omp_get_thread_num(), i); - } - } + for (std::size_t i = 0; i != tasks; ++i) { + thread_aware_function(omp_get_thread_num(), i); + } + } - /** + /** * @brief Executes tasks in bulk using the specified thread-aware function. * @param tasks The total number of tasks to be executed. * @param thread_aware_function The thread-aware function to be called for each thread index and task index. * @throws If an exception occurs during execution of the thread-aware function. - */ - template - void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { - // OpenMP cancellation points are not yet available on most platforms, and require - // the `OMP_CANCELLATION` environment variable to be set. - // http://jakascorner.com/blog/2016/08/omp-cancel.html - // if (omp_get_cancellation()) { - // #pragma omp parallel for schedule(dynamic, 1) - // for (std::size_t i = 0; i != tasks; ++i) { - // #pragma omp cancellation point for - // if (!thread_aware_function(omp_get_thread_num(), i)) { - // #pragma omp cancel for - // } - // } - // } - std::atomic_bool stop{false}; + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + // OpenMP cancellation points are not yet available on most platforms, and require + // the `OMP_CANCELLATION` environment variable to be set. + // http://jakascorner.com/blog/2016/08/omp-cancel.html + // if (omp_get_cancellation()) { + // #pragma omp parallel for schedule(dynamic, 1) + // for (std::size_t i = 0; i != tasks; ++i) { + // #pragma omp cancellation point for + // if (!thread_aware_function(omp_get_thread_num(), i)) { + // #pragma omp cancel for + // } + // } + // } + std::atomic_bool stop{false}; #pragma omp parallel for schedule(dynamic, 1) shared(stop) - for (std::size_t i = 0; i != tasks; ++i) { - if (!stop.load(std::memory_order_relaxed) && !thread_aware_function(omp_get_thread_num(), i)) - stop.store(true, std::memory_order_relaxed); - } - } + for (std::size_t i = 0; i != tasks; ++i) { + if (!stop.load(std::memory_order_relaxed) && !thread_aware_function(omp_get_thread_num(), i)) + stop.store(true, std::memory_order_relaxed); + } + } - /** + /** * @brief Saturates every available thread with the given workload, until they finish. * @param thread_aware_function The thread-aware function to be called for each thread index. 
* @throws If an exception occurs during execution of the thread-aware function. - */ - template - void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { + */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { #pragma omp parallel - { thread_aware_function(omp_get_thread_num()); } - } + { thread_aware_function(omp_get_thread_num()); } + } }; using executor_default_t = executor_openmp_t; @@ -571,67 +565,67 @@ using executor_default_t = executor_stl_t; */ template // class aligned_allocator_gt { -public: - using value_type = element_at; - using size_type = std::size_t; - using pointer = element_at*; - using const_pointer = element_at const*; - template struct rebind { - using other = aligned_allocator_gt; - }; - - constexpr std::size_t alignment() const { return alignment_ak; } - - pointer allocate(size_type length) const { - std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); - std::size_t alignment = alignment_ak; - // void* result = nullptr; - // int status = posix_memalign(&result, alignment, length_bytes); - // return status == 0 ? (pointer)result : nullptr; + public: + using value_type = element_at; + using size_type = std::size_t; + using pointer = element_at*; + using const_pointer = element_at const*; + template struct rebind { + using other = aligned_allocator_gt; + }; + + constexpr std::size_t alignment() const { return alignment_ak; } + + pointer allocate(size_type length) const { + std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); + std::size_t alignment = alignment_ak; + // void* result = nullptr; + // int status = posix_memalign(&result, alignment, length_bytes); + // return status == 0 ? (pointer)result : nullptr; #if defined(USEARCH_DEFINED_WINDOWS) - return (pointer)_aligned_malloc(length_bytes, alignment); + return (pointer)_aligned_malloc(length_bytes, alignment); #else - return (pointer)aligned_alloc(alignment, length_bytes); + return (pointer)aligned_alloc(alignment, length_bytes); #endif - } + } - void deallocate(pointer begin, size_type) const { + void deallocate(pointer begin, size_type) const { #if defined(USEARCH_DEFINED_WINDOWS) - _aligned_free(begin); + _aligned_free(begin); #else - free(begin); + free(begin); #endif - } + } }; using aligned_allocator_t = aligned_allocator_gt<>; class page_allocator_t { -public: - static constexpr std::size_t page_size() { return 4096; } + public: + static constexpr std::size_t page_size() { return 4096; } - /** + /** * @brief Allocates an @b uninitialized block of memory of the specified size. * @param count_bytes The number of bytes to allocate. * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. 
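As a usage note (a minimal sketch, not part of the patch; the template arguments are written out explicitly here as element type and alignment, since the defaults are not visible in this hunk), the aligned allocator sizes requests in elements and rounds the underlying allocation up to the alignment boundary:

    unum::usearch::aligned_allocator_gt<float, 64> alloc;
    float* buffer = alloc.allocate(256);     // 1024 bytes, 64-byte aligned
    if (buffer) {
        // ... fill and use the cache-line-friendly buffer ...
        alloc.deallocate(buffer, 256);
    }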
- */ - byte_t* allocate(std::size_t count_bytes) const noexcept { - count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); + */ + byte_t* allocate(std::size_t count_bytes) const noexcept { + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); #if defined(USEARCH_DEFINED_WINDOWS) - return (byte_t*)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); + return (byte_t*)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); #else - return (byte_t*)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); + return (byte_t*)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); #endif - } + } - void deallocate(byte_t* page_pointer, std::size_t count_bytes) const noexcept { + void deallocate(byte_t* page_pointer, std::size_t count_bytes) const noexcept { #if defined(USEARCH_DEFINED_WINDOWS) - ::VirtualFree(page_pointer, 0, MEM_RELEASE); + ::VirtualFree(page_pointer, 0, MEM_RELEASE); #else - count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); - munmap(page_pointer, count_bytes); + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); + munmap(page_pointer, count_bytes); #endif - } + } }; /** @@ -643,135 +637,135 @@ class page_allocator_t { */ template class memory_mapping_allocator_gt { - static constexpr std::size_t min_capacity() { return 1024 * 1024 * 4; } - static constexpr std::size_t capacity_multiplier() { return 2; } - static constexpr std::size_t head_size() { - /// Pointer to the the previous arena and the size of the current one. - return divide_round_up(sizeof(byte_t*) + sizeof(std::size_t)) * alignment_ak; - } - - std::mutex mutex_; - byte_t* last_arena_ = nullptr; - std::size_t last_usage_ = head_size(); - std::size_t last_capacity_ = min_capacity(); - std::size_t wasted_space_ = 0; - -public: - using value_type = byte_t; - using size_type = std::size_t; - using pointer = byte_t*; - using const_pointer = byte_t const*; - - memory_mapping_allocator_gt() = default; - memory_mapping_allocator_gt(memory_mapping_allocator_gt&& other) noexcept - : last_arena_(exchange(other.last_arena_, nullptr)), last_usage_(exchange(other.last_usage_, 0)), - last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) {} - - memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt&& other) noexcept { - std::swap(last_arena_, other.last_arena_); - std::swap(last_usage_, other.last_usage_); - std::swap(last_capacity_, other.last_capacity_); - std::swap(wasted_space_, other.wasted_space_); - return *this; - } - - ~memory_mapping_allocator_gt() noexcept { reset(); } - - /** + static constexpr std::size_t min_capacity() { return 1024 * 1024 * 4; } + static constexpr std::size_t capacity_multiplier() { return 2; } + static constexpr std::size_t head_size() { + /// Pointer to the the previous arena and the size of the current one. 
+ return divide_round_up(sizeof(byte_t*) + sizeof(std::size_t)) * alignment_ak; + } + + std::mutex mutex_; + byte_t* last_arena_ = nullptr; + std::size_t last_usage_ = head_size(); + std::size_t last_capacity_ = min_capacity(); + std::size_t wasted_space_ = 0; + + public: + using value_type = byte_t; + using size_type = std::size_t; + using pointer = byte_t*; + using const_pointer = byte_t const*; + + memory_mapping_allocator_gt() = default; + memory_mapping_allocator_gt(memory_mapping_allocator_gt&& other) noexcept + : last_arena_(exchange(other.last_arena_, nullptr)), last_usage_(exchange(other.last_usage_, 0)), + last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) {} + + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt&& other) noexcept { + std::swap(last_arena_, other.last_arena_); + std::swap(last_usage_, other.last_usage_); + std::swap(last_capacity_, other.last_capacity_); + std::swap(wasted_space_, other.wasted_space_); + return *this; + } + + ~memory_mapping_allocator_gt() noexcept { reset(); } + + /** * @brief Discards all previously allocated memory buffers. - */ - void reset() noexcept { - byte_t* last_arena = last_arena_; - while (last_arena) { - byte_t* previous_arena = nullptr; - std::memcpy(&previous_arena, last_arena, sizeof(byte_t*)); - std::size_t last_cap = 0; - std::memcpy(&last_cap, last_arena + sizeof(byte_t*), sizeof(std::size_t)); - page_allocator_t{}.deallocate(last_arena, last_cap); - last_arena = previous_arena; - } - - // Clear the references: - last_arena_ = nullptr; - last_usage_ = head_size(); - last_capacity_ = min_capacity(); - wasted_space_ = 0; - } - - /** + */ + void reset() noexcept { + byte_t* last_arena = last_arena_; + while (last_arena) { + byte_t* previous_arena = nullptr; + std::memcpy(&previous_arena, last_arena, sizeof(byte_t*)); + std::size_t last_cap = 0; + std::memcpy(&last_cap, last_arena + sizeof(byte_t*), sizeof(std::size_t)); + page_allocator_t{}.deallocate(last_arena, last_cap); + last_arena = previous_arena; + } + + // Clear the references: + last_arena_ = nullptr; + last_usage_ = head_size(); + last_capacity_ = min_capacity(); + wasted_space_ = 0; + } + + /** * @brief Copy constructor. * @note This is a no-op copy constructor since the allocator is not copyable. - */ - memory_mapping_allocator_gt(memory_mapping_allocator_gt const&) noexcept {} + */ + memory_mapping_allocator_gt(memory_mapping_allocator_gt const&) noexcept {} - /** + /** * @brief Copy assignment operator. * @note This is a no-op copy assignment operator since the allocator is not copyable. * @return Reference to the allocator after the assignment. - */ - memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt const&) noexcept { - reset(); - return *this; - } + */ + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt const&) noexcept { + reset(); + return *this; + } - /** + /** * @brief Allocates an @b uninitialized block of memory of the specified size. * @param count_bytes The number of bytes to allocate. * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. 
- */ - inline byte_t* allocate(std::size_t count_bytes) noexcept { - std::size_t extended_bytes = divide_round_up(count_bytes) * alignment_ak; - std::unique_lock lock(mutex_); - if (!last_arena_ || (last_usage_ + extended_bytes >= last_capacity_)) { - std::size_t new_cap = (std::max)(last_capacity_, ceil2(extended_bytes)) * capacity_multiplier(); - byte_t* new_arena = page_allocator_t{}.allocate(new_cap); - if (!new_arena) - return nullptr; - std::memcpy(new_arena, &last_arena_, sizeof(byte_t*)); - std::memcpy(new_arena + sizeof(byte_t*), &new_cap, sizeof(std::size_t)); - - wasted_space_ += total_reserved(); - last_arena_ = new_arena; - last_capacity_ = new_cap; - last_usage_ = head_size(); - } - - wasted_space_ += extended_bytes - count_bytes; - return last_arena_ + exchange(last_usage_, last_usage_ + extended_bytes); - } - - /** + */ + inline byte_t* allocate(std::size_t count_bytes) noexcept { + std::size_t extended_bytes = divide_round_up(count_bytes) * alignment_ak; + std::unique_lock lock(mutex_); + if (!last_arena_ || (last_usage_ + extended_bytes >= last_capacity_)) { + std::size_t new_cap = (std::max)(last_capacity_, ceil2(extended_bytes)) * capacity_multiplier(); + byte_t* new_arena = page_allocator_t{}.allocate(new_cap); + if (!new_arena) + return nullptr; + std::memcpy(new_arena, &last_arena_, sizeof(byte_t*)); + std::memcpy(new_arena + sizeof(byte_t*), &new_cap, sizeof(std::size_t)); + + wasted_space_ += total_reserved(); + last_arena_ = new_arena; + last_capacity_ = new_cap; + last_usage_ = head_size(); + } + + wasted_space_ += extended_bytes - count_bytes; + return last_arena_ + exchange(last_usage_, last_usage_ + extended_bytes); + } + + /** * @brief Returns the amount of memory used by the allocator across all arenas. * @return The amount of space in bytes. - */ - std::size_t total_allocated() const noexcept { - if (!last_arena_) - return 0; - std::size_t total_used = 0; - std::size_t last_capacity = last_capacity_; - do { - total_used += last_capacity; - last_capacity /= capacity_multiplier(); - } while (last_capacity >= min_capacity()); - return total_used; - } - - /** + */ + std::size_t total_allocated() const noexcept { + if (!last_arena_) + return 0; + std::size_t total_used = 0; + std::size_t last_capacity = last_capacity_; + do { + total_used += last_capacity; + last_capacity /= capacity_multiplier(); + } while (last_capacity >= min_capacity()); + return total_used; + } + + /** * @brief Returns the amount of wasted space due to alignment. * @return The amount of wasted space in bytes. - */ - std::size_t total_wasted() const noexcept { return wasted_space_; } + */ + std::size_t total_wasted() const noexcept { return wasted_space_; } - /** + /** * @brief Returns the amount of remaining memory already reserved but not yet used. * @return The amount of reserved memory in bytes. - */ - std::size_t total_reserved() const noexcept { return last_arena_ ? last_capacity_ - last_usage_ : 0; } + */ + std::size_t total_reserved() const noexcept { return last_arena_ ? last_capacity_ - last_usage_ : 0; } - /** + /** * @warning The very first memory de-allocation discards all the arenas! - */ - void deallocate(byte_t* = nullptr, std::size_t = 0) noexcept { reset(); } + */ + void deallocate(byte_t* = nullptr, std::size_t = 0) noexcept { reset(); } }; using memory_mapping_allocator_t = memory_mapping_allocator_gt<>; @@ -782,86 +776,86 @@ using memory_mapping_allocator_t = memory_mapping_allocator_gt<>; * but requires only a single 32-bit atomic integer to work. 
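To illustrate the arena semantics documented above (a minimal sketch, not part of the patch; `byte_t` and the allocator alias come from `unum::usearch`), buffers are bump-allocated from memory-mapped arenas and can only be released all at once:

    unum::usearch::memory_mapping_allocator_t arena;
    unum::usearch::byte_t* a = arena.allocate(100);     // carved from the first arena
    unum::usearch::byte_t* b = arena.allocate(5000);    // may grow into a new, larger arena
    // ... build graph nodes in-place ...
    arena.deallocate();                                 // releases every arena, including `a` and `b`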
*/ class unfair_shared_mutex_t { - /** Any positive integer describes the number of concurrent readers */ - enum state_t : std::int32_t { - idle_k = 0, - writing_k = -1, - }; - std::atomic state_{idle_k}; - -public: - inline void lock() noexcept { - std::int32_t raw; - relock: - raw = idle_k; - if (!state_.compare_exchange_weak(raw, writing_k, std::memory_order_acquire, std::memory_order_relaxed)) { - std::this_thread::yield(); - goto relock; - } - } - - inline void unlock() noexcept { state_.store(idle_k, std::memory_order_release); } - - inline void lock_shared() noexcept { - std::int32_t raw; - relock_shared: - raw = state_.load(std::memory_order_acquire); - // Spin while it's uniquely locked - if (raw == writing_k) { - std::this_thread::yield(); - goto relock_shared; - } - // Try incrementing the counter - if (!state_.compare_exchange_weak(raw, raw + 1, std::memory_order_acquire, std::memory_order_relaxed)) { - std::this_thread::yield(); - goto relock_shared; - } - } - - inline void unlock_shared() noexcept { state_.fetch_sub(1, std::memory_order_release); } - - /** + /** Any positive integer describes the number of concurrent readers */ + enum state_t : std::int32_t { + idle_k = 0, + writing_k = -1, + }; + std::atomic state_{idle_k}; + + public: + inline void lock() noexcept { + std::int32_t raw; + relock: + raw = idle_k; + if (!state_.compare_exchange_weak(raw, writing_k, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock; + } + } + + inline void unlock() noexcept { state_.store(idle_k, std::memory_order_release); } + + inline void lock_shared() noexcept { + std::int32_t raw; + relock_shared: + raw = state_.load(std::memory_order_acquire); + // Spin while it's uniquely locked + if (raw == writing_k) { + std::this_thread::yield(); + goto relock_shared; + } + // Try incrementing the counter + if (!state_.compare_exchange_weak(raw, raw + 1, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock_shared; + } + } + + inline void unlock_shared() noexcept { state_.fetch_sub(1, std::memory_order_release); } + + /** * @brief Try upgrades the current `lock_shared()` to a unique `lock()` state. - */ - inline bool try_escalate() noexcept { - std::int32_t one_read = 1; - return state_.compare_exchange_weak(one_read, writing_k, std::memory_order_acquire, std::memory_order_relaxed); - } + */ + inline bool try_escalate() noexcept { + std::int32_t one_read = 1; + return state_.compare_exchange_weak(one_read, writing_k, std::memory_order_acquire, std::memory_order_relaxed); + } - /** + /** * @brief Escalates current lock potentially loosing control in the middle. * It's a shortcut for `try_escalate`-`unlock_shared`-`lock` trio. - */ - inline void unsafe_escalate() noexcept { - if (!try_escalate()) { - unlock_shared(); - lock(); - } - } - - /** + */ + inline void unsafe_escalate() noexcept { + if (!try_escalate()) { + unlock_shared(); + lock(); + } + } + + /** * @brief Upgrades the current `lock_shared()` to a unique `lock()` state. - */ - inline void escalate() noexcept { - while (!try_escalate()) - std::this_thread::yield(); - } + */ + inline void escalate() noexcept { + while (!try_escalate()) + std::this_thread::yield(); + } - /** + /** * @brief De-escalation of a previously escalated state. 
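A minimal sketch of the intended protocol (not part of the patch; error handling omitted): any number of readers may hold the shared lock, a single writer takes the exclusive lock, and a lone reader can upgrade in place:

    unum::usearch::unfair_shared_mutex_t mutex;

    // Reader path: many threads may be between these calls concurrently.
    mutex.lock_shared();
    // ... read-only access ...
    mutex.unlock_shared();

    // Writer path: exclusive access.
    mutex.lock();
    // ... mutation ...
    mutex.unlock();

    // Upgrade path: succeeds only if this thread is currently the sole reader.
    mutex.lock_shared();
    if (mutex.try_escalate()) {
        // ... exclusive access ...
        mutex.unlock();
    } else {
        mutex.unlock_shared();
    }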
- */ - inline void de_escalate() noexcept { - std::int32_t one_read = 1; - state_.store(one_read, std::memory_order_release); - } + */ + inline void de_escalate() noexcept { + std::int32_t one_read = 1; + state_.store(one_read, std::memory_order_release); + } }; template class shared_lock_gt { - mutex_at& mutex_; + mutex_at& mutex_; -public: - inline explicit shared_lock_gt(mutex_at& m) noexcept : mutex_(m) { mutex_.lock_shared(); } - inline ~shared_lock_gt() noexcept { mutex_.unlock_shared(); } + public: + inline explicit shared_lock_gt(mutex_at& m) noexcept : mutex_(m) { mutex_.lock_shared(); } + inline ~shared_lock_gt() noexcept { mutex_.unlock_shared(); } }; /** @@ -869,53 +863,65 @@ template class shared_lock_gt { * avoiding unnecessary conversions. */ template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { - from_scalar_at const* typed_input = reinterpret_cast(input); - to_scalar_at* typed_output = reinterpret_cast(output); - auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; - std::transform(typed_input, typed_input + dim, typed_output, converter); - return true; - } + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; + std::transform(typed_input, typed_input + dim, typed_output, converter); + return true; + } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; template <> struct cast_gt { - bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } + bool operator()(byte_t const*, std::size_t, byte_t*) const { return false; } }; template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { - from_scalar_at const* typed_input = reinterpret_cast(input); - unsigned char* typed_output = reinterpret_cast(output); - for (std::size_t i = 0; i != dim; ++i) - typed_output[i / CHAR_BIT] |= bool(typed_input[i]) ? (128 >> (i & (CHAR_BIT - 1))) : 0; - return true; - } + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + from_scalar_at const* typed_input = reinterpret_cast(input); + unsigned char* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + // Converting from scalar types to boolean isn't trivial and depends on the type. + // The most common case is to consider all positive values as `true` and all others as `false`. 
+ // - `bool(0.00001f)` converts to 1 + // - `bool(-0.00001f)` converts to 1 + // - `bool(0)` converts to 0 + // - `bool(-0)` converts to 0 + // - `bool(std::numeric_limits::infinity())` converts to 1 + // - `bool(std::numeric_limits::epsilon())` converts to 1 + // - `bool(std::numeric_limits::signaling_NaN())` converts to 1 + // - `bool(std::numeric_limits::denorm_min())` converts to 1 + typed_output[i / CHAR_BIT] |= bool(typed_input[i] > 0) ? (128 >> (i & (CHAR_BIT - 1))) : 0; + return true; + } }; template struct cast_gt { - inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { - unsigned char const* typed_input = reinterpret_cast(input); - to_scalar_at* typed_output = reinterpret_cast(output); - for (std::size_t i = 0; i != dim; ++i) - typed_output[i] = bool(typed_input[i / CHAR_BIT] & (128 >> (i & (CHAR_BIT - 1)))); - return true; - } + inline bool operator()(byte_t const* input, std::size_t dim, byte_t* output) const { + unsigned char const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + // We can't entirely reconstruct the original scalar type from a boolean. + // The simplest variant would be to map set bits to ones, and unset bits to zeros. + typed_output[i] = bool(typed_input[i / CHAR_BIT] & (128 >> (i & (CHAR_BIT - 1)))); + return true; + } }; /** @@ -923,36 +929,36 @@ template struct cast_gt { * values within [-1,1] range, quantized to integers [-100,100]. */ class i8_converted_t { - std::int8_t int8_{}; - -public: - constexpr static f32_t divisor_k = 100.f; - constexpr static std::int8_t min_k = -100; - constexpr static std::int8_t max_k = 100; - - inline i8_converted_t() noexcept : int8_(0) {} - inline i8_converted_t(bool v) noexcept : int8_(v ? max_k : 0) {} - - inline i8_converted_t(i8_converted_t&&) = default; - inline i8_converted_t& operator=(i8_converted_t&&) = default; - inline i8_converted_t(i8_converted_t const&) = default; - inline i8_converted_t& operator=(i8_converted_t const&) = default; - - inline operator f16_t() const noexcept { return static_cast(f32_t(int8_) / divisor_k); } - inline operator f32_t() const noexcept { return f32_t(int8_) / divisor_k; } - inline operator f64_t() const noexcept { return f64_t(int8_) / divisor_k; } - inline explicit operator bool() const noexcept { return int8_ > (max_k / 2); } - inline explicit operator std::int8_t() const noexcept { return int8_; } - inline explicit operator std::int16_t() const noexcept { return int8_; } - inline explicit operator std::int32_t() const noexcept { return int8_; } - inline explicit operator std::int64_t() const noexcept { return int8_; } - - inline i8_converted_t(f16_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} - inline i8_converted_t(f32_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} - inline i8_converted_t(f64_t v) - : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + std::int8_t int8_{}; + + public: + constexpr static f32_t divisor_k = 100.f; + constexpr static std::int8_t min_k = -100; + constexpr static std::int8_t max_k = 100; + + inline i8_converted_t() noexcept : int8_(0) {} + inline i8_converted_t(bool v) noexcept : int8_(v ? 
max_k : 0) {} + + inline i8_converted_t(i8_converted_t&&) = default; + inline i8_converted_t& operator=(i8_converted_t&&) = default; + inline i8_converted_t(i8_converted_t const&) = default; + inline i8_converted_t& operator=(i8_converted_t const&) = default; + + inline operator f16_t() const noexcept { return static_cast(f32_t(int8_) / divisor_k); } + inline operator f32_t() const noexcept { return f32_t(int8_) / divisor_k; } + inline operator f64_t() const noexcept { return f64_t(int8_) / divisor_k; } + inline explicit operator bool() const noexcept { return int8_ > (max_k / 2); } + inline explicit operator std::int8_t() const noexcept { return int8_; } + inline explicit operator std::int16_t() const noexcept { return int8_; } + inline explicit operator std::int32_t() const noexcept { return int8_; } + inline explicit operator std::int64_t() const noexcept { return int8_; } + + inline i8_converted_t(f16_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f32_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} + inline i8_converted_t(f64_t v) + : int8_(usearch::clamp(static_cast(v * divisor_k), min_k, max_k)) {} }; f16_bits_t::f16_bits_t(i8_converted_t v) noexcept : uint16_(f32_to_f16(v)) {} @@ -969,11 +975,11 @@ template <> struct cast_gt : public cast_gt * @brief Inner (Dot) Product distance. */ template struct metric_ip_gt { - using scalar_t = scalar_at; - using result_t = result_at; + using scalar_t = scalar_at; + using result_t = result_at; - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { - result_t ab{}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab) #elif defined(USEARCH_DEFINED_CLANG) @@ -981,10 +987,10 @@ template struct met #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) - ab += result_t(a[i]) * result_t(b[i]); - return 1 - ab; - } + for (std::size_t i = 0; i != dim; ++i) + ab += result_t(a[i]) * result_t(b[i]); + return 1 - ab; + } }; /** @@ -994,11 +1000,11 @@ template struct met * is recommended over `::metric_ip_gt` for low-precision scalars. 
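To make the convention concrete (a minimal sketch, not part of the patch; template arguments are spelled out because the defaults are not visible in this hunk), the inner-product metric above returns `1 - <a, b>`, so it only behaves like a distance for unit-length vectors:

    float a[3] = {1.f, 0.f, 0.f};
    float b[3] = {0.f, 1.f, 0.f};
    unum::usearch::metric_ip_gt<float, float> ip;
    float d_same = ip(a, a, 3);   // 1 - 1 = 0 for identical unit vectors
    float d_orth = ip(a, b, 3);   // 1 - 0 = 1 for orthogonal unit vectors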
*/ template struct metric_cos_gt { - using scalar_t = scalar_at; - using result_t = result_at; + using scalar_t = scalar_at; + using result_t = result_at; - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { - result_t ab{}, a2{}, b2{}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}, a2{}, b2{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab, a2, b2) #elif defined(USEARCH_DEFINED_CLANG) @@ -1006,18 +1012,18 @@ template struct met #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) { - result_t ai = static_cast(a[i]); - result_t bi = static_cast(b[i]); - ab += ai * bi, a2 += square(ai), b2 += square(bi); - } - - result_t result_if_zero[2][2]; - result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); - result_if_zero[0][1] = result_if_zero[1][0] = 1; - result_if_zero[1][1] = 0; - return result_if_zero[a2 == 0][b2 == 0]; - } + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab += ai * bi, a2 += square(ai), b2 += square(bi); + } + + result_t result_if_zero[2][2]; + result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); + result_if_zero[0][1] = result_if_zero[1][0] = 1; + result_if_zero[1][1] = 0; + return result_if_zero[a2 == 0][b2 == 0]; + } }; /** @@ -1025,11 +1031,11 @@ template struct met * Square root is avoided at the end, as it won't affect the ordering. */ template struct metric_l2sq_gt { - using scalar_t = scalar_at; - using result_t = result_at; + using scalar_t = scalar_at; + using result_t = result_at; - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { - result_t ab_deltas_sq{}; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab_deltas_sq{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab_deltas_sq) #elif defined(USEARCH_DEFINED_CLANG) @@ -1037,13 +1043,13 @@ template struct met #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) { - result_t ai = static_cast(a[i]); - result_t bi = static_cast(b[i]); - ab_deltas_sq += square(ai - bi); - } - return ab_deltas_sq; - } + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab_deltas_sq += square(ai - bi); + } + return ab_deltas_sq; + } }; /** @@ -1052,16 +1058,16 @@ template struct met * tokenized and hashed into a fixed-capacity bitset. 
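A small sketch (not part of the patch) of the zero-vector handling encoded in the `result_if_zero` table of `metric_cos_gt` above: a zero vector against a non-zero one is treated as maximally distant, while two zero vectors compare as identical:

    float unit[2] = {1.f, 0.f};
    float zero[2] = {0.f, 0.f};
    unum::usearch::metric_cos_gt<float, float> cos_metric;
    float d_self = cos_metric(unit, unit, 2);   // 0: same direction
    float d_zero = cos_metric(unit, zero, 2);   // 1: one side has no direction
    float d_both = cos_metric(zero, zero, 2);   // 0: both degenerate, treated as equal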
*/ template struct metric_hamming_gt { - using scalar_t = scalar_at; - using result_t = result_at; - static_assert( // - std::is_unsigned::value || - (std::is_enum::value && std::is_unsigned::type>::value), - "Hamming distance requires unsigned integral words"); - - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { - constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; - result_t matches{}; + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Hamming distance requires unsigned integral words"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t matches{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : matches) #elif defined(USEARCH_DEFINED_CLANG) @@ -1069,10 +1075,10 @@ template #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != words; ++i) - matches += std::bitset(a[i] ^ b[i]).count(); - return matches; - } + for (std::size_t i = 0; i != words; ++i) + matches += std::bitset(a[i] ^ b[i]).count(); + return matches; + } }; /** @@ -1080,18 +1086,18 @@ template * Often used in chemistry and biology to compare molecular fingerprints. */ template struct metric_tanimoto_gt { - using scalar_t = scalar_at; - using result_t = result_at; - static_assert( // - std::is_unsigned::value || - (std::is_enum::value && std::is_unsigned::type>::value), - "Tanimoto distance requires unsigned integral words"); - static_assert(std::is_floating_point::value, "Tanimoto distance will be a fraction"); - - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { - constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; - result_t and_count{}; - result_t or_count{}; + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Tanimoto distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Tanimoto distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t or_count{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : and_count, or_count) #elif defined(USEARCH_DEFINED_CLANG) @@ -1099,12 +1105,12 @@ template struct #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != words; ++i) { - and_count += std::bitset(a[i] & b[i]).count(); - or_count += std::bitset(a[i] | b[i]).count(); - } - return 1 - result_t(and_count) / or_count; - } + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + or_count += std::bitset(a[i] | b[i]).count(); + } + return 1 - result_t(and_count) / or_count; + } }; /** @@ -1112,18 +1118,18 @@ template struct * Often used in chemistry and biology to compare molecular fingerprints. 
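For reference (a minimal sketch, not part of the patch; template arguments are written out explicitly), the Hamming metric above counts differing bits across unsigned words, here two 128-bit fingerprints packed into 64-bit words:

    std::uint64_t a[2] = {0x0Bu, 0u};
    std::uint64_t b[2] = {0x02u, 0u};
    unum::usearch::metric_hamming_gt<std::uint64_t, std::size_t> hamming;
    std::size_t differing_bits = hamming(a, b, 2);   // 0x0B ^ 0x02 = 0x09 -> 2 bits differ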
*/ template struct metric_sorensen_gt { - using scalar_t = scalar_at; - using result_t = result_at; - static_assert( // - std::is_unsigned::value || - (std::is_enum::value && std::is_unsigned::type>::value), - "Sorensen-Dice distance requires unsigned integral words"); - static_assert(std::is_floating_point::value, "Sorensen-Dice distance will be a fraction"); - - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { - constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; - result_t and_count{}; - result_t any_count{}; + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Sorensen-Dice distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Sorensen-Dice distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t any_count{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : and_count, any_count) #elif defined(USEARCH_DEFINED_CLANG) @@ -1131,12 +1137,12 @@ template struct #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != words; ++i) { - and_count += std::bitset(a[i] & b[i]).count(); - any_count += std::bitset(a[i]).count() + std::bitset(b[i]).count(); - } - return 1 - 2 * result_t(and_count) / any_count; - } + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + any_count += std::bitset(a[i]).count() + std::bitset(b[i]).count(); + } + return 1 - 2 * result_t(and_count) / any_count; + } }; /** @@ -1146,41 +1152,44 @@ template struct * Similar to `metric_tanimoto_gt` for dense representations. */ template struct metric_jaccard_gt { - using scalar_t = scalar_at; - using result_t = result_at; - static_assert(!std::is_floating_point::value, "Jaccard distance requires integral scalars"); - - inline result_t operator()( // - scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept { - result_t intersection{}; - std::size_t i{}; - std::size_t j{}; - while (i != a_length && j != b_length) { - intersection += a[i] == b[j]; - i += a[i] < b[j]; - j += a[i] >= b[j]; - } - return 1 - intersection / (a_length + b_length - intersection); - } + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_floating_point::value, "Jaccard distance requires integral scalars"); + static_assert(std::is_floating_point::value, "Jaccard distance returns a fraction"); + + inline result_t operator()( // + scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept { + std::size_t intersection{}; + std::size_t i{}; + std::size_t j{}; + while (i != a_length && j != b_length) { + scalar_t ai = a[i]; + scalar_t bj = b[j]; + intersection += ai == bj; + i += ai < bj; + j += ai >= bj; + } + return 1 - static_cast(intersection) / (a_length + b_length - intersection); + } }; /** * @brief Measures Pearson Correlation between two sequences in a single pass. */ template struct metric_pearson_gt { - using scalar_t = scalar_at; - using result_t = result_at; - - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { - // The correlation coefficient can't be defined for one or zero-dimensional data. 
- if (dim <= 1) - return 0; - // Conventional Pearson Correlation Coefficient definiton subtracts the mean value of each - // sequence from each element, before dividing them. WikiPedia article suggests a convenient - // single-pass algorithm for calculating sample correlations, though depending on the numbers - // involved, it can sometimes be numerically unstable. - result_t a_sum{}, b_sum{}, ab_sum{}; - result_t a_sq_sum{}, b_sq_sum{}; + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + // The correlation coefficient can't be defined for one or zero-dimensional data. + if (dim <= 1) + return 0; + // Conventional Pearson Correlation Coefficient definiton subtracts the mean value of each + // sequence from each element, before dividing them. WikiPedia article suggests a convenient + // single-pass algorithm for calculating sample correlations, though depending on the numbers + // involved, it can sometimes be numerically unstable. + result_t a_sum{}, b_sum{}, ab_sum{}; + result_t a_sq_sum{}, b_sq_sum{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : a_sum, b_sum, ab_sum, a_sq_sum, b_sq_sum) #elif defined(USEARCH_DEFINED_CLANG) @@ -1188,34 +1197,36 @@ template struct metric_ #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) { - result_t ai = static_cast(a[i]); - result_t bi = static_cast(b[i]); - a_sum += ai; - b_sum += bi; - ab_sum += ai * bi; - a_sq_sum += ai * ai; - b_sq_sum += bi * bi; - } - result_t denom = (dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum); - if (denom == 0) - return 0; - result_t corr = dim * ab_sum - a_sum * b_sum; - denom = std::sqrt(denom); - return -corr / denom; - } + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + a_sum += ai; + b_sum += bi; + ab_sum += ai * bi; + a_sq_sum += ai * ai; + b_sq_sum += bi * bi; + } + result_t denom = (dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum); + if (denom == 0) + return 0; + result_t corr = dim * ab_sum - a_sum * b_sum; + denom = std::sqrt(denom); + // The normal Pearson correlation value is between -1 and 1, but we are looking for a distance. + // So instead of returning `corr / denom`, we return `1 - corr / denom`. + return 1 - corr / denom; + } }; /** * @brief Measures Jensen-Shannon Divergence between two probability distributions. 
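 *
 * In the notation of the loop below: with m_i = (p_i + q_i) / 2, the result is
 * (KL(p || m) + KL(q || m)) / 2, where KL(x || m) = sum_i x_i * log(x_i / m_i);
 * a small epsilon keeps the logarithms finite when a probability is zero.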
*/ template struct metric_divergence_gt { - using scalar_t = scalar_at; - using result_t = result_at; + using scalar_t = scalar_at; + using result_t = result_at; - inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept { - result_t kld_pm{}, kld_qm{}; - result_t epsilon = std::numeric_limits::epsilon(); + inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept { + result_t kld_pm{}, kld_qm{}; + result_t epsilon = std::numeric_limits::epsilon(); #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : kld_pm, kld_qm) #elif defined(USEARCH_DEFINED_CLANG) @@ -1223,23 +1234,23 @@ template struct metric_ #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; ++i) { - result_t pi = static_cast(p[i]); - result_t qi = static_cast(q[i]); - result_t mi = (pi + qi) / 2 + epsilon; - kld_pm += pi * std::log((pi + epsilon) / mi); - kld_qm += qi * std::log((qi + epsilon) / mi); - } - return (kld_pm + kld_qm) / 2; - } + for (std::size_t i = 0; i != dim; ++i) { + result_t pi = static_cast(p[i]); + result_t qi = static_cast(q[i]); + result_t mi = (pi + qi) / 2 + epsilon; + kld_pm += pi * std::log((pi + epsilon) / mi); + kld_qm += qi * std::log((qi + epsilon) / mi); + } + return (kld_pm + kld_qm) / 2; + } }; struct cos_i8_t { - using scalar_t = i8_t; - using result_t = f32_t; + using scalar_t = i8_t; + using result_t = f32_t; - inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { - std::int32_t ab{}, a2{}, b2{}; + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab{}, a2{}, b2{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab, a2, b2) #elif defined(USEARCH_DEFINED_CLANG) @@ -1247,23 +1258,25 @@ struct cos_i8_t { #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; i++) { - std::int16_t ai{a[i]}; - std::int16_t bi{b[i]}; - ab += ai * bi; - a2 += square(ai); - b2 += square(bi); - } - return (ab != 0) ? (1.f - ab / (std::sqrt(a2) * std::sqrt(b2))) : 0; - } + for (std::size_t i = 0; i != dim; i++) { + std::int16_t ai{a[i]}; + std::int16_t bi{b[i]}; + ab += ai * bi; + a2 += square(ai); + b2 += square(bi); + } + result_t a2f = std::sqrt(static_cast(a2)); + result_t b2f = std::sqrt(static_cast(b2)); + return (ab != 0) ? (1.f - ab / (a2f * b2f)) : 0; + } }; struct l2sq_i8_t { - using scalar_t = i8_t; - using result_t = f32_t; + using scalar_t = i8_t; + using result_t = f32_t; - inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { - std::int32_t ab_deltas_sq{}; + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab_deltas_sq{}; #if USEARCH_USE_OPENMP #pragma omp simd reduction(+ : ab_deltas_sq) #elif defined(USEARCH_DEFINED_CLANG) @@ -1271,10 +1284,10 @@ struct l2sq_i8_t { #elif defined(USEARCH_DEFINED_GCC) #pragma GCC ivdep #endif - for (std::size_t i = 0; i != dim; i++) - ab_deltas_sq += square(std::int16_t(a[i]) - std::int16_t(b[i])); - return static_cast(ab_deltas_sq); - } + for (std::size_t i = 0; i != dim; i++) + ab_deltas_sq += square(std::int16_t(a[i]) - std::int16_t(b[i])); + return static_cast(ab_deltas_sq); + } }; /** @@ -1282,25 +1295,26 @@ struct l2sq_i8_t { * the surface of a 3D sphere, defined with latitude and longitude. 
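 *
 * In the notation of the implementation below: with latitudes and longitudes given in degrees,
 * x = sin^2(d_lat / 2) + cos(lat_a) * cos(lat_b) * sin^2(d_lon / 2), and the returned value is
 * 2 * asin(sqrt(x)), i.e. the central angle in radians. Multiplying by the sphere radius
 * (roughly 6371 km for Earth) converts it to a distance, if needed.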
*/ template struct metric_haversine_gt { - using scalar_t = scalar_at; - using result_t = result_at; - static_assert(!std::is_integral::value, "Latitude and longitude must be floating-node"); + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_integral::value && !std::is_same::value, + "Latitude and longitude must be floating-node"); - inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept { - result_t lat_a = a[0], lon_a = a[1]; - result_t lat_b = b[0], lon_b = b[1]; + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept { + result_t lat_a = a[0], lon_a = a[1]; + result_t lat_b = b[0], lon_b = b[1]; - result_t lat_delta = angle_to_radians(lat_b - lat_a) / 2; - result_t lon_delta = angle_to_radians(lon_b - lon_a) / 2; + result_t lat_delta = angle_to_radians(lat_b - lat_a) / 2; + result_t lon_delta = angle_to_radians(lon_b - lon_a) / 2; - result_t converted_lat_a = angle_to_radians(lat_a); - result_t converted_lat_b = angle_to_radians(lat_b); + result_t converted_lat_a = angle_to_radians(lat_a); + result_t converted_lat_b = angle_to_radians(lat_b); - result_t x = square(std::sin(lat_delta)) + // - std::cos(converted_lat_a) * std::cos(converted_lat_b) * square(std::sin(lon_delta)); + result_t x = square(std::sin(lat_delta)) + // + std::cos(converted_lat_a) * std::cos(converted_lat_b) * square(std::sin(lon_delta)); - return 2 * std::asin(std::sqrt(x)); - } + return 2 * std::asin(std::sqrt(x)); + } }; using distance_punned_t = float; @@ -1312,231 +1326,327 @@ using span_punned_t = span_gt; * or include one or two array sizes as 64-bit unsigned integers. */ enum class metric_punned_signature_t { - array_array_k = 0, - array_array_size_k, + array_array_k = 0, + array_array_size_k, + array_array_state_k, }; /** * @brief Type-punned metric class, which unlike STL's `std::function` avoids any memory allocations. * It also provides additional APIs to check, if SIMD hardware-acceleration is available. - * Wraps the `simsimd_metric_punned_t`. + * Wraps the `simsimd_metric_punned_t` when available. The auto-vectorized backend otherwise. */ class metric_punned_t { -public: - using scalar_t = byte_t; - using result_t = distance_punned_t; - -private: - using punned_arg_t = std::size_t; - using punned_ptr_t = result_t (*)(std::size_t, std::size_t, std::size_t, std::size_t); - - punned_ptr_t raw_ptr_ = nullptr; - punned_arg_t raw_arg3_ = 0; - punned_arg_t raw_arg4_ = 0; - - std::size_t dimensions_ = 0; - metric_kind_t metric_kind_ = metric_kind_t::unknown_k; - scalar_kind_t scalar_kind_ = scalar_kind_t::unknown_k; + public: + using scalar_t = byte_t; + using result_t = distance_punned_t; + + private: + /// In the generalized function API all the are arguments are pointer-sized. + using uptr_t = std::size_t; + /// Distance function that takes two arrays and returns a scalar. + using metric_array_array_t = result_t (*)(uptr_t, uptr_t); + /// Distance function that takes two arrays and their length and returns a scalar. + using metric_array_array_size_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function that takes two arrays and some callback state and returns a scalar. + using metric_array_array_state_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function callback, like `metric_array_array_size_t`, but depends on member variables. 
+ using metric_rounted_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const; + + metric_rounted_t metric_routed_ = nullptr; + uptr_t metric_ptr_ = 0; + uptr_t metric_third_arg_ = 0; + + std::size_t dimensions_ = 0; + metric_kind_t metric_kind_ = metric_kind_t::unknown_k; + scalar_kind_t scalar_kind_ = scalar_kind_t::unknown_k; #if USEARCH_USE_SIMSIMD - simsimd_capability_t isa_kind_ = simsimd_cap_serial_k; + simsimd_capability_t isa_kind_ = simsimd_cap_serial_k; #endif -public: - /** + public: + /** * @brief Computes the distance between two vectors of fixed length. * * ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic. - */ - inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept { - return raw_ptr_(reinterpret_cast(a), reinterpret_cast(b), raw_arg3_, raw_arg4_); - } - - inline metric_punned_t() noexcept = default; - inline metric_punned_t(metric_punned_t const&) noexcept = default; - inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default; - inline metric_punned_t( // - std::size_t dimensions, // - metric_kind_t metric_kind = metric_kind_t::l2sq_k, // - scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept - : raw_arg3_(dimensions), raw_arg4_(dimensions), dimensions_(dimensions), metric_kind_(metric_kind), - scalar_kind_(scalar_kind) { + */ + inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept { + return (this->*metric_routed_)(reinterpret_cast(a), reinterpret_cast(b)); + } + + inline metric_punned_t() noexcept = default; + inline metric_punned_t(metric_punned_t const&) noexcept = default; + inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default; + + inline metric_punned_t(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept + : metric_punned_t(builtin(dimensions, metric_kind, scalar_kind)) {} + + inline metric_punned_t(std::size_t dimensions, std::uintptr_t metric_uintptr, metric_punned_signature_t signature, + metric_kind_t metric_kind, scalar_kind_t scalar_kind) noexcept + : metric_punned_t(stateless(dimensions, metric_uintptr, signature, metric_kind, scalar_kind)) {} + + /** + * @brief Creates a metric of a natively supported kind, choosing the best + * available backend internally or from SimSIMD. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t builtin(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = 0; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? 
divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; #if USEARCH_USE_SIMSIMD - if (!configure_with_simsimd()) - configure_with_auto_vectorized(); + if (!metric.configure_with_simsimd()) + metric.configure_with_autovec(); #else - configure_with_auto_vectorized(); + metric.configure_with_autovec(); #endif - if (scalar_kind == scalar_kind_t::b1x8_k) - raw_arg3_ = raw_arg4_ = divide_round_up(dimensions_); - } + return metric; + } - inline metric_punned_t( // - std::size_t dimensions, // - std::uintptr_t metric_uintptr, metric_punned_signature_t signature, // - metric_kind_t metric_kind, // - scalar_kind_t scalar_kind) noexcept - : raw_ptr_(reinterpret_cast(metric_uintptr)), dimensions_(dimensions), metric_kind_(metric_kind), - scalar_kind_(scalar_kind) { - - // We don't need to explicitly parse signature, as all of them are compatible. - (void)signature; - } + /** + * @brief Creates a metric using the provided function pointer for a stateless metric. + * So the provided ::metric_uintptr is a pointer to a function that takes two arrays + * and returns a scalar. If the ::signature is metric_punned_signature_t::array_array_size_k, + * then the third argument is the number of scalar words in the input vectors. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_uintptr The function pointer to the metric function. + * @param signature The signature of the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t stateless(std::size_t dimensions, std::uintptr_t metric_uintptr, + metric_punned_signature_t signature, metric_kind_t metric_kind, + scalar_kind_t scalar_kind) noexcept { + metric_punned_t metric; + metric.metric_routed_ = signature == metric_punned_signature_t::array_array_k + ? &metric_punned_t::invoke_array_array + : &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + /** + * @brief Creates a metric using the provided function pointer for a statefull metric. + * The third argument is the state that will be passed to the metric function. + * + * @param metric_uintptr The function pointer to the metric function. + * @param metric_state The state to pass to the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. 
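 *
 * Regardless of which factory is used, the resulting object is invoked on type-punned
 * byte pointers. A usage sketch with illustrative values:
 *
 *     metric_punned_t metric(3, metric_kind_t::cos_k, scalar_kind_t::f32_k);
 *     float a[3] = {1.f, 0.f, 0.f}, b[3] = {0.f, 1.f, 0.f};
 *     distance_punned_t d = metric((byte_t const*)a, (byte_t const*)b); // ~1 for orthogonal vectors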
+ */ + inline static metric_punned_t statefull(std::uintptr_t metric_uintptr, std::uintptr_t metric_state, + metric_kind_t metric_kind = metric_kind_t::unknown_k, + scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = metric_state; + metric.dimensions_ = 0; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + inline std::size_t dimensions() const noexcept { return dimensions_; } + inline metric_kind_t metric_kind() const noexcept { return metric_kind_; } + inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; } + inline explicit operator bool() const noexcept { return metric_routed_ && metric_ptr_; } + + /** + * @brief Checks fi we've failed to initialized the metric with provided arguments. + * + * It's different from `operator bool()` when it comes to explicitly uninitialized metrics. + * It's a common case, where a NULL state is created only to be overwritten later, when + * we recover an old index state from a file or a network. + */ + inline bool missing() const noexcept { return !bool(*this) && metric_kind_ != metric_kind_t::unknown_k; } - inline std::size_t dimensions() const noexcept { return dimensions_; } - inline metric_kind_t metric_kind() const noexcept { return metric_kind_; } - inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; } + inline char const* isa_name() const noexcept { + if (!*this) + return "uninitialized"; - inline char const* isa_name() const noexcept { #if USEARCH_USE_SIMSIMD - switch (isa_kind_) { - case simsimd_cap_serial_k: return "serial"; - case simsimd_cap_arm_neon_k: return "neon"; - case simsimd_cap_arm_sve_k: return "sve"; - case simsimd_cap_x86_avx2_k: return "avx2"; - case simsimd_cap_x86_avx512_k: return "avx512"; - case simsimd_cap_x86_avx2fp16_k: return "avx2+f16"; - case simsimd_cap_x86_avx512fp16_k: return "avx512+f16"; - case simsimd_cap_x86_avx512vpopcntdq_k: return "avx512+popcnt"; - default: return "unknown"; - } + switch (isa_kind_) { + case simsimd_cap_serial_k: return "serial"; + case simsimd_cap_neon_k: return "neon"; + case simsimd_cap_sve_k: return "sve"; + case simsimd_cap_haswell_k: return "haswell"; + case simsimd_cap_skylake_k: return "skylake"; + case simsimd_cap_ice_k: return "ice"; + case simsimd_cap_sapphire_k: return "sapphire"; + default: return "unknown"; + } #endif - return "serial"; - } + return "serial"; + } - inline std::size_t bytes_per_vector() const noexcept { - return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_)); - } + inline std::size_t bytes_per_vector() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_)); + } - inline std::size_t scalar_words() const noexcept { - return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_), bits_per_scalar_word(scalar_kind_)); - } + inline std::size_t scalar_words() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_), bits_per_scalar_word(scalar_kind_)); + } -private: + private: #if USEARCH_USE_SIMSIMD - bool configure_with_simsimd(simsimd_capability_t simd_caps) noexcept { - simsimd_metric_kind_t kind = simsimd_metric_unknown_k; - simsimd_datatype_t datatype = simsimd_datatype_unknown_k; - simsimd_capability_t allowed = simsimd_cap_any_k; - switch (metric_kind_) { - case metric_kind_t::ip_k: kind = simsimd_metric_ip_k; break; - case 
metric_kind_t::cos_k: kind = simsimd_metric_cos_k; break; - case metric_kind_t::l2sq_k: kind = simsimd_metric_l2sq_k; break; - case metric_kind_t::hamming_k: kind = simsimd_metric_hamming_k; break; - case metric_kind_t::tanimoto_k: kind = simsimd_metric_jaccard_k; break; - case metric_kind_t::jaccard_k: kind = simsimd_metric_jaccard_k; break; - default: break; - } - switch (scalar_kind_) { - case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; - case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; - case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; - case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; - case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; - default: break; - } - simsimd_metric_punned_t simd_metric = NULL; - simsimd_capability_t simd_kind = simsimd_cap_any_k; - simsimd_find_metric_punned(kind, datatype, simd_caps, allowed, &simd_metric, &simd_kind); - if (simd_metric == nullptr) - return false; - - std::memcpy(&raw_ptr_, &simd_metric, sizeof(simd_metric)); - isa_kind_ = simd_kind; - return true; - } - bool configure_with_simsimd() noexcept { - static simsimd_capability_t static_capabilities = simsimd_capabilities(); - return configure_with_simsimd(static_capabilities); - } + bool configure_with_simsimd(simsimd_capability_t simd_caps) noexcept { + simsimd_metric_kind_t kind = simsimd_metric_unknown_k; + simsimd_datatype_t datatype = simsimd_datatype_unknown_k; + simsimd_capability_t allowed = simsimd_cap_any_k; + switch (metric_kind_) { + case metric_kind_t::ip_k: kind = simsimd_metric_dot_k; break; + case metric_kind_t::cos_k: kind = simsimd_metric_cos_k; break; + case metric_kind_t::l2sq_k: kind = simsimd_metric_l2sq_k; break; + case metric_kind_t::hamming_k: kind = simsimd_metric_hamming_k; break; + case metric_kind_t::tanimoto_k: kind = simsimd_metric_jaccard_k; break; + case metric_kind_t::jaccard_k: kind = simsimd_metric_jaccard_k; break; + default: break; + } + switch (scalar_kind_) { + case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; + case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; + case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; + case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; + case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; + default: break; + } + simsimd_metric_punned_t simd_metric = NULL; + simsimd_capability_t simd_kind = simsimd_cap_any_k; + simsimd_find_metric_punned(kind, datatype, simd_caps, allowed, &simd_metric, &simd_kind); + if (simd_metric == nullptr) + return false; + + std::memcpy(&metric_ptr_, &simd_metric, sizeof(simd_metric)); + metric_routed_ = metric_kind_ == metric_kind_t::ip_k + ? reinterpret_cast(&metric_punned_t::invoke_simsimd_reverse) + : reinterpret_cast(&metric_punned_t::invoke_simsimd); + isa_kind_ = simd_kind; + return true; + } + bool configure_with_simsimd() noexcept { + static simsimd_capability_t static_capabilities = simsimd_capabilities(); + return configure_with_simsimd(static_capabilities); + } + result_t invoke_simsimd(uptr_t a, uptr_t b) const noexcept { + simsimd_distance_t result; + // Here `reinterpret_cast` raises warning... we know what we are doing! 
+ auto function_pointer = (simsimd_metric_punned_t)(metric_ptr_); + function_pointer(reinterpret_cast(a), reinterpret_cast(b), metric_third_arg_, + &result); + return (result_t)result; + } + result_t invoke_simsimd_reverse(uptr_t a, uptr_t b) const noexcept { return 1 - invoke_simsimd(a, b); } #else - bool configure_with_simsimd() noexcept { return false; } + bool configure_with_simsimd() noexcept { return false; } #endif - - void configure_with_auto_vectorized() noexcept { - switch (metric_kind_) { - case metric_kind_t::ip_k: { - switch (scalar_kind_) { - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::cos_k: { - switch (scalar_kind_) { - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::l2sq_k: { - switch (scalar_kind_) { - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f16_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::pearson_k: { - switch (scalar_kind_) { - case scalar_kind_t::i8_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::haversine_k: { - switch (scalar_kind_) { - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::divergence_k: { - switch (scalar_kind_) { - case scalar_kind_t::f16_k: - raw_ptr_ = (punned_ptr_t)&equidimensional_>; - break; - case scalar_kind_t::f32_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case scalar_kind_t::f64_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: raw_ptr_ = nullptr; break; - } - break; - } - case metric_kind_t::jaccard_k: // Equivalent to Tanimoto - case metric_kind_t::tanimoto_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case metric_kind_t::hamming_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - case metric_kind_t::sorensen_k: raw_ptr_ = (punned_ptr_t)&equidimensional_>; break; - default: return; - } - } - - template - inline static result_t equidimensional_( // - punned_arg_t a, punned_arg_t b, // - punned_arg_t a_dimensions, punned_arg_t b_dimensions) noexcept { - using scalar_t = typename typed_at::scalar_t; - 
(void)b_dimensions; - return typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions); - } + result_t invoke_array_array_third(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_size_t)(metric_ptr_); + result_t result = function_pointer(a, b, metric_third_arg_); + return result; + } + result_t invoke_array_array(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_t)(metric_ptr_); + result_t result = function_pointer(a, b); + return result; + } + void configure_with_autovec() noexcept { + switch (metric_kind_) { + case metric_kind_t::ip_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::cos_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::l2sq_k: { + switch (scalar_kind_) { + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::pearson_k: { + switch (scalar_kind_) { + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::haversine_k: { + switch (scalar_kind_) { + case scalar_kind_t::f16_k: metric_ptr_ = 0; break; //< Having half-precision 2D coordinates is a bit silly. 
+ case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::divergence_k: { + switch (scalar_kind_) { + case scalar_kind_t::f16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::jaccard_k: // Equivalent to Tanimoto + case metric_kind_t::tanimoto_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::hamming_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::sorensen_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: return; + } + } + + template + inline static result_t equidimensional_(uptr_t a, uptr_t b, uptr_t a_dimensions) noexcept { + using scalar_t = typename typed_at::scalar_t; + return static_cast(typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions)); + } }; /** @@ -1544,37 +1654,37 @@ class metric_punned_t { */ template // class vectors_view_gt { - using scalar_t = scalar_at; - - scalar_t const* begin_{}; - std::size_t dimensions_{}; - std::size_t count_{}; - std::size_t stride_bytes_{}; - -public: - vectors_view_gt() noexcept = default; - vectors_view_gt(vectors_view_gt const&) noexcept = default; - vectors_view_gt& operator=(vectors_view_gt const&) noexcept = default; - - vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count = 1) noexcept - : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {} - - vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept - : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {} - - explicit operator bool() const noexcept { return begin_; } - std::size_t size() const noexcept { return count_; } - std::size_t dimensions() const noexcept { return dimensions_; } - std::size_t stride() const noexcept { return stride_bytes_; } - scalar_t const* data() const noexcept { return begin_; } - scalar_t const* at(std::size_t i) const noexcept { - return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); - } + using scalar_t = scalar_at; + + scalar_t const* begin_{}; + std::size_t dimensions_{}; + std::size_t count_{}; + std::size_t stride_bytes_{}; + + public: + vectors_view_gt() noexcept = default; + vectors_view_gt(vectors_view_gt const&) noexcept = default; + vectors_view_gt& operator=(vectors_view_gt const&) noexcept = default; + + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count = 1) noexcept + : vectors_view_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {} + + vectors_view_gt(scalar_t const* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept + : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {} + + explicit operator bool() const noexcept { return begin_; } + std::size_t size() const noexcept { return count_; } + std::size_t dimensions() const noexcept { return dimensions_; } + std::size_t stride() const noexcept { return stride_bytes_; } + scalar_t const* data() const noexcept { return begin_; } + scalar_t const* at(std::size_t i) const noexcept { + return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); + } 
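    // Usage sketch (illustrative values): viewing three 4-dimensional vectors stored row-major
    // without copying; the stride defaults to dimensions * sizeof(scalar_at).
    //
    //     float rows[3][4] = {};
    //     vectors_view_gt<float> view(&rows[0][0], /*dimensions=*/4, /*count=*/3);
    //     float const* second = view.at(1); // points at rows[1]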
}; struct exact_offset_and_distance_t { - u32_t offset; - f32_t distance; + u32_t offset; + f32_t distance; }; using exact_search_results_t = vectors_view_gt; @@ -1589,97 +1699,97 @@ using exact_search_results_t = vectors_view_gt; */ class exact_search_t { - inline static bool smaller_distance(exact_offset_and_distance_t a, exact_offset_and_distance_t b) noexcept { - return a.distance < b.distance; - } - - using keys_and_distances_t = buffer_gt; - keys_and_distances_t keys_and_distances; - -public: - template - exact_search_results_t operator()( // - vectors_view_gt dataset, vectors_view_gt queries, // - std::size_t wanted, metric_punned_t const& metric, // - executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { - return operator()( // - metric, // - reinterpret_cast(dataset.data()), dataset.size(), dataset.stride(), // - reinterpret_cast(queries.data()), queries.size(), queries.stride(), // - wanted, executor, progress); - } - - template - exact_search_results_t operator()( // - byte_t const* dataset_data, std::size_t dataset_count, std::size_t dataset_stride, // - byte_t const* queries_data, std::size_t queries_count, std::size_t queries_stride, // - std::size_t wanted, metric_punned_t const& metric, executor_at&& executor = executor_at{}, - progress_at&& progress = progress_at{}) { - - // Allocate temporary memory to store the distance matrix - // Previous version didn't need temporary memory, but the performance was much lower. - // In the new design we keep two buffers - original and transposed, as in-place transpositions - // of non-rectangular matrixes is expensive. - std::size_t tasks_count = dataset_count * queries_count; - if (keys_and_distances.size() < tasks_count * 2) - keys_and_distances = keys_and_distances_t(tasks_count * 2); - if (keys_and_distances.size() < tasks_count * 2) - return {}; - - exact_offset_and_distance_t* keys_and_distances_per_dataset = keys_and_distances.data(); - exact_offset_and_distance_t* keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count; - - // §1. Compute distances in a data-parallel fashion - std::atomic processed{0}; - executor.dynamic(dataset_count, [&](std::size_t thread_idx, std::size_t dataset_idx) { - byte_t const* dataset = dataset_data + dataset_idx * dataset_stride; - for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { - byte_t const* query = queries_data + query_idx * queries_stride; - auto distance = metric(dataset, query); - std::size_t task_idx = queries_count * dataset_idx + query_idx; - keys_and_distances_per_dataset[task_idx].offset = static_cast(dataset_idx); - keys_and_distances_per_dataset[task_idx].distance = static_cast(distance); - } - - // It's more efficient in this case to report progress from a single thread - processed += queries_count; - if (thread_idx == 0) - if (!progress(processed.load(), tasks_count)) - return false; - return true; - }); - if (processed.load() != tasks_count) - return {}; - - // §2. Transpose in a single thread to avoid contention writing into the same memory buffers - for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { - for (std::size_t dataset_idx = 0; dataset_idx != dataset_count; ++dataset_idx) { - std::size_t from_idx = queries_count * dataset_idx + query_idx; - std::size_t to_idx = dataset_count * query_idx + dataset_idx; - keys_and_distances_per_query[to_idx] = keys_and_distances_per_dataset[from_idx]; - } - } - - // §3. 
Partial-sort every query result - executor.fixed(queries_count, [&](std::size_t, std::size_t query_idx) { - auto start = keys_and_distances_per_query + dataset_count * query_idx; - if (wanted > 1) { - // TODO: Consider alternative sorting approaches - // radix_sort(start, start + dataset_count, wanted); - // std::sort(start, start + dataset_count, &smaller_distance); - std::partial_sort(start, start + wanted, start + dataset_count, &smaller_distance); - } else { - auto min_it = std::min_element(start, start + dataset_count, &smaller_distance); - if (min_it != start) - std::swap(*min_it, *start); - } - }); - - // At the end report the latest numbers, because the reporter thread may be finished earlier - progress(tasks_count, tasks_count); - return {keys_and_distances_per_query, wanted, queries_count, - dataset_count * sizeof(exact_offset_and_distance_t)}; - } + inline static bool smaller_distance(exact_offset_and_distance_t a, exact_offset_and_distance_t b) noexcept { + return a.distance < b.distance; + } + + using keys_and_distances_t = buffer_gt; + keys_and_distances_t keys_and_distances; + + public: + template + exact_search_results_t operator()( // + vectors_view_gt dataset, vectors_view_gt queries, // + std::size_t wanted, metric_punned_t const& metric, // + executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + return operator()( // + metric, // + reinterpret_cast(dataset.data()), dataset.size(), dataset.stride(), // + reinterpret_cast(queries.data()), queries.size(), queries.stride(), // + wanted, executor, progress); + } + + template + exact_search_results_t operator()( // + byte_t const* dataset_data, std::size_t dataset_count, std::size_t dataset_stride, // + byte_t const* queries_data, std::size_t queries_count, std::size_t queries_stride, // + std::size_t wanted, metric_punned_t const& metric, executor_at&& executor = executor_at{}, + progress_at&& progress = progress_at{}) { + + // Allocate temporary memory to store the distance matrix + // Previous version didn't need temporary memory, but the performance was much lower. + // In the new design we keep two buffers - original and transposed, as in-place transpositions + // of non-rectangular matrixes is expensive. + std::size_t tasks_count = dataset_count * queries_count; + if (keys_and_distances.size() < tasks_count * 2) + keys_and_distances = keys_and_distances_t(tasks_count * 2); + if (keys_and_distances.size() < tasks_count * 2) + return {}; + + exact_offset_and_distance_t* keys_and_distances_per_dataset = keys_and_distances.data(); + exact_offset_and_distance_t* keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count; + + // §1. 
Compute distances in a data-parallel fashion + std::atomic processed{0}; + executor.dynamic(dataset_count, [&](std::size_t thread_idx, std::size_t dataset_idx) { + byte_t const* dataset = dataset_data + dataset_idx * dataset_stride; + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + byte_t const* query = queries_data + query_idx * queries_stride; + auto distance = metric(dataset, query); + std::size_t task_idx = queries_count * dataset_idx + query_idx; + keys_and_distances_per_dataset[task_idx].offset = static_cast(dataset_idx); + keys_and_distances_per_dataset[task_idx].distance = static_cast(distance); + } + + // It's more efficient in this case to report progress from a single thread + processed += queries_count; + if (thread_idx == 0) + if (!progress(processed.load(), tasks_count)) + return false; + return true; + }); + if (processed.load() != tasks_count) + return {}; + + // §2. Transpose in a single thread to avoid contention writing into the same memory buffers + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + for (std::size_t dataset_idx = 0; dataset_idx != dataset_count; ++dataset_idx) { + std::size_t from_idx = queries_count * dataset_idx + query_idx; + std::size_t to_idx = dataset_count * query_idx + dataset_idx; + keys_and_distances_per_query[to_idx] = keys_and_distances_per_dataset[from_idx]; + } + } + + // §3. Partial-sort every query result + executor.fixed(queries_count, [&](std::size_t, std::size_t query_idx) { + auto start = keys_and_distances_per_query + dataset_count * query_idx; + if (wanted > 1) { + // TODO: Consider alternative sorting approaches + // radix_sort(start, start + dataset_count, wanted); + // std::sort(start, start + dataset_count, &smaller_distance); + std::partial_sort(start, start + wanted, start + dataset_count, &smaller_distance); + } else { + auto min_it = std::min_element(start, start + dataset_count, &smaller_distance); + if (min_it != start) + std::swap(*min_it, *start); + } + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(tasks_count, tasks_count); + return {keys_and_distances_per_query, wanted, queries_count, + dataset_count * sizeof(exact_offset_and_distance_t)}; + } }; /** @@ -1697,517 +1807,517 @@ class exact_search_t { */ template > class flat_hash_multi_set_gt { -public: - using element_t = element_at; - using hash_t = hash_at; - using equals_t = equals_at; - using allocator_t = allocator_at; - - static constexpr std::size_t slots_per_bucket() { return 64; } - static constexpr std::size_t bytes_per_bucket() { - return slots_per_bucket() * sizeof(element_t) + sizeof(bucket_header_t); - } - -private: - struct bucket_header_t { - std::uint64_t populated{}; - std::uint64_t deleted{}; - }; - char* data_ = nullptr; - std::size_t buckets_ = 0; - std::size_t populated_slots_ = 0; - /// @brief Number of slots - std::size_t capacity_slots_ = 0; - - struct slot_ref_t { - bucket_header_t& header; - std::uint64_t mask; - element_t& element; - }; - - slot_ref_t slot_ref(char* data, std::size_t slot_index) const noexcept { - std::size_t bucket_index = slot_index / slots_per_bucket(); - std::size_t in_bucket_index = slot_index % slots_per_bucket(); - auto bucket_pointer = data + bytes_per_bucket() * bucket_index; - auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; - return { - *reinterpret_cast(bucket_pointer), - static_cast(1ull) << in_bucket_index, - *reinterpret_cast(slot_pointer), 
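    // A usage sketch for the exact_search_t helper above (names and counts are illustrative;
    // the default executor and progress callbacks are assumed):
    //
    //     exact_search_t searcher;
    //     metric_punned_t metric(dims, metric_kind_t::l2sq_k, scalar_kind_t::f32_k);
    //     exact_search_results_t top = searcher(
    //         vectors_view_gt<float>(dataset, dims, dataset_count),
    //         vectors_view_gt<float>(queries, dims, queries_count),
    //         /*wanted=*/10, metric);
    //     // top.at(q) points at the 10 closest (offset, distance) pairs for query q.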
- }; - } - - slot_ref_t slot_ref(std::size_t slot_index) const noexcept { return slot_ref(data_, slot_index); } - - bool populate_slot(slot_ref_t slot, element_t const& new_element) { - if (slot.header.populated & slot.mask) { - slot.element = new_element; - slot.header.deleted &= ~slot.mask; - return false; - } else { - new (&slot.element) element_t(new_element); - slot.header.populated |= slot.mask; - return true; - } - } - -public: - std::size_t size() const noexcept { return populated_slots_; } - std::size_t capacity() const noexcept { return capacity_slots_; } - - flat_hash_multi_set_gt() noexcept {} - ~flat_hash_multi_set_gt() noexcept { reset(); } - - flat_hash_multi_set_gt(flat_hash_multi_set_gt const& other) { - - // On Windows allocating a zero-size array would fail - if (!other.buckets_) { - reset(); - return; - } - - // Allocate new memory - data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); - if (!data_) - throw std::bad_alloc(); - - // Copy metadata - buckets_ = other.buckets_; - populated_slots_ = other.populated_slots_; - capacity_slots_ = other.capacity_slots_; - - // Initialize new buckets to empty - std::memset(data_, 0, buckets_ * bytes_per_bucket()); - - // Copy elements and bucket headers - for (std::size_t i = 0; i < capacity_slots_; ++i) { - slot_ref_t old_slot = other.slot_ref(i); - if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { - slot_ref_t new_slot = slot_ref(i); - populate_slot(new_slot, old_slot.element); - } - } - } - - flat_hash_multi_set_gt& operator=(flat_hash_multi_set_gt const& other) { - - // On Windows allocating a zero-size array would fail - if (!other.buckets_) { - reset(); - return *this; - } - - // Handle self-assignment - if (this == &other) - return *this; - - // Clear existing data - clear(); - if (data_) - allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); - - // Allocate new memory - data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); - if (!data_) - throw std::bad_alloc(); - - // Copy metadata - buckets_ = other.buckets_; - populated_slots_ = other.populated_slots_; - capacity_slots_ = other.capacity_slots_; - - // Initialize new buckets to empty - std::memset(data_, 0, buckets_ * bytes_per_bucket()); - - // Copy elements and bucket headers - for (std::size_t i = 0; i < capacity_slots_; ++i) { - slot_ref_t old_slot = other.slot_ref(i); - if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { - slot_ref_t new_slot = slot_ref(i); - populate_slot(new_slot, old_slot.element); - } - } - - return *this; - } - - void clear() noexcept { - // Call the destructors - for (std::size_t i = 0; i < capacity_slots_; ++i) { - slot_ref_t slot = slot_ref(i); - if ((slot.header.populated & slot.mask) & (~slot.header.deleted & slot.mask)) - slot.element.~element_t(); - } - - // Reset populated slots count - if (data_) - std::memset(data_, 0, buckets_ * bytes_per_bucket()); - populated_slots_ = 0; - } - - void reset() noexcept { - clear(); // Clear all elements - if (data_) - allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); - buckets_ = 0; - populated_slots_ = 0; - capacity_slots_ = 0; - } - - bool try_reserve(std::size_t capacity) noexcept { - if (capacity * 3u <= capacity_slots_ * 2u) - return true; - - // Calculate new sizes - std::size_t new_slots = ceil2((capacity * 3ul) / 2ul); - std::size_t new_buckets = divide_round_up(new_slots); - new_slots = new_buckets * slots_per_bucket(); // This must 
be a power of two! - std::size_t new_bytes = new_buckets * bytes_per_bucket(); - - // Allocate new memory - char* new_data = (char*)allocator_t{}.allocate(new_bytes); - if (!new_data) - return false; - - // Initialize new buckets to empty - std::memset(new_data, 0, new_bytes); - - // Rehash and copy existing elements to new_data - hash_t hasher; - for (std::size_t i = 0; i < capacity_slots_; ++i) { - slot_ref_t old_slot = slot_ref(i); - if ((~old_slot.header.populated & old_slot.mask) | (old_slot.header.deleted & old_slot.mask)) - continue; - - // Rehash - std::size_t hash_value = hasher(old_slot.element); - std::size_t new_slot_index = hash_value & (new_slots - 1); - - // Linear probing to find an empty slot in new_data - while (true) { - slot_ref_t new_slot = slot_ref(new_data, new_slot_index); - if (!(new_slot.header.populated & new_slot.mask) || (new_slot.header.deleted & new_slot.mask)) { - populate_slot(new_slot, std::move(old_slot.element)); - new_slot.header.populated |= new_slot.mask; - break; - } - new_slot_index = (new_slot_index + 1) & (new_slots - 1); - } - } - - // Deallocate old data and update pointers and sizes - if (data_) - allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); - data_ = new_data; - buckets_ = new_buckets; - capacity_slots_ = new_slots; - - return true; - } - - template class equal_iterator_gt { - public: - using iterator_category = std::forward_iterator_tag; - using value_type = element_t; - using difference_type = std::ptrdiff_t; - using pointer = element_t*; - using reference = element_t&; - - equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt* parent, query_at const& query, - equals_t const& equals) - : index_(index), parent_(parent), query_(query), equals_(equals) {} - - // Pre-increment - equal_iterator_gt& operator++() { - do { - index_ = (index_ + 1) & (parent_->capacity_slots_ - 1); - } while (!equals_(parent_->slot_ref(index_).element, query_) && - (parent_->slot_ref(index_).header.populated & parent_->slot_ref(index_).mask)); - return *this; - } - - equal_iterator_gt operator++(int) { - equal_iterator_gt temp = *this; - ++(*this); - return temp; - } - - reference operator*() { return parent_->slot_ref(index_).element; } - pointer operator->() { return &parent_->slot_ref(index_).element; } - bool operator!=(equal_iterator_gt const& other) const { return !(*this == other); } - bool operator==(equal_iterator_gt const& other) const { - return index_ == other.index_ && parent_ == other.parent_; - } - - private: - std::size_t index_; - flat_hash_multi_set_gt* parent_; - query_at query_; // Store the query object - equals_t equals_; // Store the equals functor - }; - - /** + public: + using element_t = element_at; + using hash_t = hash_at; + using equals_t = equals_at; + using allocator_t = allocator_at; + + static constexpr std::size_t slots_per_bucket() { return 64; } + static constexpr std::size_t bytes_per_bucket() { + return slots_per_bucket() * sizeof(element_t) + sizeof(bucket_header_t); + } + + private: + struct bucket_header_t { + std::uint64_t populated{}; + std::uint64_t deleted{}; + }; + char* data_ = nullptr; + std::size_t buckets_ = 0; + std::size_t populated_slots_ = 0; + /// @brief Number of slots + std::size_t capacity_slots_ = 0; + + struct slot_ref_t { + bucket_header_t& header; + std::uint64_t mask; + element_t& element; + }; + + slot_ref_t slot_ref(char* data, std::size_t slot_index) const noexcept { + std::size_t bucket_index = slot_index / slots_per_bucket(); + std::size_t in_bucket_index = slot_index % 
slots_per_bucket(); + auto bucket_pointer = data + bytes_per_bucket() * bucket_index; + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + return { + *reinterpret_cast(bucket_pointer), + static_cast(1ull) << in_bucket_index, + *reinterpret_cast(slot_pointer), + }; + } + + slot_ref_t slot_ref(std::size_t slot_index) const noexcept { return slot_ref(data_, slot_index); } + + bool populate_slot(slot_ref_t slot, element_t const& new_element) { + if (slot.header.populated & slot.mask) { + slot.element = new_element; + slot.header.deleted &= ~slot.mask; + return false; + } else { + new (&slot.element) element_t(new_element); + slot.header.populated |= slot.mask; + return true; + } + } + + public: + std::size_t size() const noexcept { return populated_slots_; } + std::size_t capacity() const noexcept { return capacity_slots_; } + + flat_hash_multi_set_gt() noexcept {} + ~flat_hash_multi_set_gt() noexcept { reset(); } + + flat_hash_multi_set_gt(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return; + } + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + throw std::bad_alloc(); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + } + + flat_hash_multi_set_gt& operator=(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return *this; + } + + // Handle self-assignment + if (this == &other) + return *this; + + // Clear existing data + clear(); + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + throw std::bad_alloc(); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + + return *this; + } + + void clear() noexcept { + // Call the destructors + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t slot = slot_ref(i); + if ((slot.header.populated & slot.mask) & (~slot.header.deleted & slot.mask)) + slot.element.~element_t(); + } + + // Reset populated slots count + if (data_) + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + populated_slots_ = 0; + } + + void reset() noexcept { + clear(); // Clear all elements + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + buckets_ = 0; + populated_slots_ = 0; + capacity_slots_ = 0; + } + 
+ bool try_reserve(std::size_t capacity) noexcept { + if (capacity * 3u <= capacity_slots_ * 2u) + return true; + + // Calculate new sizes + std::size_t new_slots = ceil2((capacity * 3ul) / 2ul); + std::size_t new_buckets = divide_round_up(new_slots); + new_slots = new_buckets * slots_per_bucket(); // This must be a power of two! + std::size_t new_bytes = new_buckets * bytes_per_bucket(); + + // Allocate new memory + char* new_data = (char*)allocator_t{}.allocate(new_bytes); + if (!new_data) + return false; + + // Initialize new buckets to empty + std::memset(new_data, 0, new_bytes); + + // Rehash and copy existing elements to new_data + hash_t hasher; + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = slot_ref(i); + if ((~old_slot.header.populated & old_slot.mask) | (old_slot.header.deleted & old_slot.mask)) + continue; + + // Rehash + std::size_t hash_value = hasher(old_slot.element); + std::size_t new_slot_index = hash_value & (new_slots - 1); + + // Linear probing to find an empty slot in new_data + while (true) { + slot_ref_t new_slot = slot_ref(new_data, new_slot_index); + if (!(new_slot.header.populated & new_slot.mask) || (new_slot.header.deleted & new_slot.mask)) { + populate_slot(new_slot, std::move(old_slot.element)); + new_slot.header.populated |= new_slot.mask; + break; + } + new_slot_index = (new_slot_index + 1) & (new_slots - 1); + } + } + + // Deallocate old data and update pointers and sizes + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + data_ = new_data; + buckets_ = new_buckets; + capacity_slots_ = new_slots; + + return true; + } + + template class equal_iterator_gt { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = element_t*; + using reference = element_t&; + + equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt* parent, query_at const& query, + equals_t const& equals) + : index_(index), parent_(parent), query_(query), equals_(equals) {} + + // Pre-increment + equal_iterator_gt& operator++() { + do { + index_ = (index_ + 1) & (parent_->capacity_slots_ - 1); + } while (!equals_(parent_->slot_ref(index_).element, query_) && + (parent_->slot_ref(index_).header.populated & parent_->slot_ref(index_).mask)); + return *this; + } + + equal_iterator_gt operator++(int) { + equal_iterator_gt temp = *this; + ++(*this); + return temp; + } + + reference operator*() { return parent_->slot_ref(index_).element; } + pointer operator->() { return &parent_->slot_ref(index_).element; } + bool operator!=(equal_iterator_gt const& other) const { return !(*this == other); } + bool operator==(equal_iterator_gt const& other) const { + return index_ == other.index_ && parent_ == other.parent_; + } + + private: + std::size_t index_; + flat_hash_multi_set_gt* parent_; + query_at query_; // Store the query object + equals_t equals_; // Store the equals functor + }; + + /** * @brief Returns an iterator range of all elements matching the given query. * * Technically, the second iterator points to the first empty slot after a * range of equal values and non-equal values with similar hashes. 
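 *
 * A lookup sketch (element type and key are illustrative):
 *
 *     flat_hash_multi_set_gt<std::string> names;
 *     auto range = names.equal_range(std::string("alice"));
 *     for (auto it = range.first; it != range.second; ++it)
 *         consume(*it); // visits every stored copy of "alice"
 *
 * The range is located by hashing the query and probing linearly through the
 * power-of-two slot array until an empty slot terminates the run.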
- */ - template - std::pair, equal_iterator_gt> - equal_range(query_at const& query) const noexcept { - - equals_t equals; - auto this_ptr = const_cast(this); - auto end = equal_iterator_gt(capacity_slots_, this_ptr, query, equals); - if (!capacity_slots_) - return {end, end}; - - hash_t hasher; - std::size_t hash_value = hasher(query); - std::size_t first_equal_index = hash_value & (capacity_slots_ - 1); - std::size_t const start_index = first_equal_index; - - // Linear probing to find the first equal element - do { - slot_ref_t slot = slot_ref(first_equal_index); - if (slot.header.populated & ~slot.header.deleted & slot.mask) { - if (equals(slot.element, query)) - break; - } - // Stop if we find an empty slot - else if (~slot.header.populated & slot.mask) - return {end, end}; - - // Move to the next slot - first_equal_index = (first_equal_index + 1) & (capacity_slots_ - 1); - } while (first_equal_index != start_index); - - // If no matching element was found, return end iterators - if (first_equal_index == capacity_slots_) - return {end, end}; - - // Start from the first matching element and find the end of the populated range - std::size_t first_empty_index = first_equal_index; - do { - first_empty_index = (first_empty_index + 1) & (capacity_slots_ - 1); - slot_ref_t slot = slot_ref(first_empty_index); - - // If we find an empty slot, this is our end - if (~slot.header.populated & slot.mask) - break; - } while (first_empty_index != start_index); - - return {equal_iterator_gt(first_equal_index, this_ptr, query, equals), - equal_iterator_gt(first_empty_index, this_ptr, query, equals)}; - } - - template bool pop_first(similar_at&& query, element_t& popped_value) noexcept { - - if (!capacity_slots_) - return false; - - hash_t hasher; - equals_t equals; - std::size_t hash_value = hasher(query); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - std::size_t start_index = slot_index; // To detect loop in probing - - // Linear probing to find the first match - do { - slot_ref_t slot = slot_ref(slot_index); - if (slot.header.populated & slot.mask) { - if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { - // Found a match, mark as deleted - slot.header.deleted |= slot.mask; - --populated_slots_; - popped_value = slot.element; - return true; // Successfully removed - } - } else { - // Stop if we find an empty slot - break; - } - - // Move to the next slot - slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - } while (slot_index != start_index); - - return false; // No match found - } - - template std::size_t erase(similar_at&& query) noexcept { - - if (!capacity_slots_) - return 0; - - hash_t hasher; - equals_t equals; - std::size_t hash_value = hasher(query); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - std::size_t const start_index = slot_index; // To detect loop in probing - std::size_t count = 0; // Count of elements removed - - // Linear probing to find all matches - do { - slot_ref_t slot = slot_ref(slot_index); - if (slot.header.populated & slot.mask) { - if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { - // Found a match, mark as deleted - slot.header.deleted |= slot.mask; - --populated_slots_; - ++count; // Increment count of elements removed - } - } else { - // Stop if we find an empty slot - break; - } - - // Move to the next slot - slot_index = (slot_index + 1) & 
(capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - } while (slot_index != start_index); - - return count; // Return the number of elements removed - } - - template element_t const* find(similar_at&& query) const noexcept { - - if (!capacity_slots_) - return nullptr; - - hash_t hasher; - equals_t equals; - std::size_t hash_value = hasher(query); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - std::size_t start_index = slot_index; // To detect loop in probing - - // Linear probing to find the first match - do { - slot_ref_t slot = slot_ref(slot_index); - if (slot.header.populated & slot.mask) { - if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) - return &slot.element; // Found a match, return pointer to the element - } else { - // Stop if we find an empty slot - break; - } - - // Move to the next slot - slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 - } while (slot_index != start_index); - - return nullptr; // No match found - } - - element_t const* end() const noexcept { return nullptr; } - - template void for_each(func_at&& func) const { - for (std::size_t bucket_index = 0; bucket_index < buckets_; ++bucket_index) { - auto bucket_pointer = data_ + bytes_per_bucket() * bucket_index; - bucket_header_t& header = *reinterpret_cast(bucket_pointer); - std::uint64_t populated = header.populated; - std::uint64_t deleted = header.deleted; - - // Iterate through slots in the bucket - for (std::size_t in_bucket_index = 0; in_bucket_index < slots_per_bucket(); ++in_bucket_index) { - std::uint64_t mask = std::uint64_t(1ull) << in_bucket_index; - - // Check if the slot is populated and not deleted - if ((populated & ~deleted) & mask) { - auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; - element_t const& element = *reinterpret_cast(slot_pointer); - func(element); - } - } - } - } - - template std::size_t count(similar_at&& query) const noexcept { - - if (!capacity_slots_) - return 0; - - hash_t hasher; - equals_t equals; - std::size_t hash_value = hasher(query); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); - std::size_t start_index = slot_index; // To detect loop in probing - std::size_t count = 0; - - // Linear probing to find the range - do { - slot_ref_t slot = slot_ref(slot_index); - if ((slot.header.populated & slot.mask) && (~slot.header.deleted & slot.mask)) { - if (equals(slot.element, query)) - ++count; - } else if (~slot.header.populated & slot.mask) { - // Stop if we find an empty slot - break; - } - - // Move to the next slot - slot_index = (slot_index + 1) & (capacity_slots_ - 1); - } while (slot_index != start_index); - - return count; - } - - template bool contains(similar_at&& query) const noexcept { - - if (!capacity_slots_) - return false; - - hash_t hasher; - equals_t equals; - std::size_t hash_value = hasher(query); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); - std::size_t start_index = slot_index; // To detect loop in probing - - // Linear probing to find the first match - do { - slot_ref_t slot = slot_ref(slot_index); - if (slot.header.populated & slot.mask) { - if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) - return true; // Found a match, exit early - } else - // Stop if we find an empty slot - break; - - // Move to the next slot - slot_index = (slot_index + 1) & (capacity_slots_ - 1); - } while (slot_index != 
start_index); - - return false; // No match found - } - - void reserve(std::size_t capacity) { - if (!try_reserve(capacity)) - throw std::bad_alloc(); - } - - bool try_emplace(element_t const& element) noexcept { - // Check if we need to resize - if (populated_slots_ * 3u >= capacity_slots_ * 2u) - if (!try_reserve(populated_slots_ + 1)) - return false; - - hash_t hasher; - std::size_t hash_value = hasher(element); - std::size_t slot_index = hash_value & (capacity_slots_ - 1); - - // Linear probing - while (true) { - slot_ref_t slot = slot_ref(slot_index); - if ((~slot.header.populated & slot.mask) | (slot.header.deleted & slot.mask)) { - // Found an empty or deleted slot - populate_slot(slot, element); - ++populated_slots_; - return true; - } - // Move to the next slot - slot_index = (slot_index + 1) & (capacity_slots_ - 1); - } - } + */ + template + std::pair, equal_iterator_gt> + equal_range(query_at const& query) const noexcept { + + equals_t equals; + auto this_ptr = const_cast(this); + auto end = equal_iterator_gt(capacity_slots_, this_ptr, query, equals); + if (!capacity_slots_) + return {end, end}; + + hash_t hasher; + std::size_t hash_value = hasher(query); + std::size_t first_equal_index = hash_value & (capacity_slots_ - 1); + std::size_t const start_index = first_equal_index; + + // Linear probing to find the first equal element + do { + slot_ref_t slot = slot_ref(first_equal_index); + if (slot.header.populated & ~slot.header.deleted & slot.mask) { + if (equals(slot.element, query)) + break; + } + // Stop if we find an empty slot + else if (~slot.header.populated & slot.mask) + return {end, end}; + + // Move to the next slot + first_equal_index = (first_equal_index + 1) & (capacity_slots_ - 1); + } while (first_equal_index != start_index); + + // If no matching element was found, return end iterators + if (first_equal_index == capacity_slots_) + return {end, end}; + + // Start from the first matching element and find the end of the populated range + std::size_t first_empty_index = first_equal_index; + do { + first_empty_index = (first_empty_index + 1) & (capacity_slots_ - 1); + slot_ref_t slot = slot_ref(first_empty_index); + + // If we find an empty slot, this is our end + if (~slot.header.populated & slot.mask) + break; + } while (first_empty_index != start_index); + + return {equal_iterator_gt(first_equal_index, this_ptr, query, equals), + equal_iterator_gt(first_empty_index, this_ptr, query, equals)}; + } + + template bool pop_first(similar_at&& query, element_t& popped_value) noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + popped_value = slot.element; + return true; // Successfully removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return false; // No match found + } + + template std::size_t erase(similar_at&& query) noexcept 
{ + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t const start_index = slot_index; // To detect loop in probing + std::size_t count = 0; // Count of elements removed + + // Linear probing to find all matches + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + ++count; // Increment count of elements removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return count; // Return the number of elements removed + } + + template element_t const* find(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return nullptr; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return &slot.element; // Found a match, return pointer to the element + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return nullptr; // No match found + } + + element_t const* end() const noexcept { return nullptr; } + + template void for_each(func_at&& func) const { + for (std::size_t bucket_index = 0; bucket_index < buckets_; ++bucket_index) { + auto bucket_pointer = data_ + bytes_per_bucket() * bucket_index; + bucket_header_t& header = *reinterpret_cast(bucket_pointer); + std::uint64_t populated = header.populated; + std::uint64_t deleted = header.deleted; + + // Iterate through slots in the bucket + for (std::size_t in_bucket_index = 0; in_bucket_index < slots_per_bucket(); ++in_bucket_index) { + std::uint64_t mask = std::uint64_t(1ull) << in_bucket_index; + + // Check if the slot is populated and not deleted + if ((populated & ~deleted) & mask) { + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + element_t const& element = *reinterpret_cast(slot_pointer); + func(element); + } + } + } + } + + template std::size_t count(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + std::size_t count = 0; + + // Linear probing to find the range + do { + slot_ref_t slot = slot_ref(slot_index); + if ((slot.header.populated & slot.mask) && (~slot.header.deleted & slot.mask)) { + if (equals(slot.element, query)) + ++count; + } else if (~slot.header.populated & slot.mask) { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 
1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return count; + } + + template bool contains(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return true; // Found a match, exit early + } else + // Stop if we find an empty slot + break; + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return false; // No match found + } + + void reserve(std::size_t capacity) { + if (!try_reserve(capacity)) + throw std::bad_alloc(); + } + + bool try_emplace(element_t const& element) noexcept { + // Check if we need to resize + if (populated_slots_ * 3u >= capacity_slots_ * 2u) + if (!try_reserve(populated_slots_ + 1)) + return false; + + hash_t hasher; + std::size_t hash_value = hasher(element); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + + // Linear probing + while (true) { + slot_ref_t slot = slot_ref(slot_index); + if ((~slot.header.populated & slot.mask) | (slot.header.deleted & slot.mask)) { + // Found an empty or deleted slot + populate_slot(slot, element); + ++populated_slots_; + return true; + } + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } + } }; } // namespace usearch -} // namespace unum \ No newline at end of file +} // namespace unum From 755276c5a3e5f08dbf30947e62403e26102bd4cf Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Wed, 26 Jun 2024 15:38:10 +0200 Subject: [PATCH 2/2] fix key --- src/include/usearch/index_dense.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/include/usearch/index_dense.hpp b/src/include/usearch/index_dense.hpp index 563e039..3e811e3 100644 --- a/src/include/usearch/index_dense.hpp +++ b/src/include/usearch/index_dense.hpp @@ -726,7 +726,7 @@ class index_dense_gt { cluster_config.thread = lock.thread_id; cluster_config.expansion = config_.expansion_search; metric_proxy_t metric{*this}; - auto &free_key_ = this->free_key; + auto &free_key_ = this->free_key_; auto allow = [&free_key_](member_cref_t const& member) noexcept { return member.key != free_key_; };
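
The follow-up commit is a one-character fix, but it touches a subtle piece of C++: a lambda capture list cannot name a non-static data member, so the surrounding code first binds a local reference to the member and captures that alias. The previous spelling pointed the alias at `this->free_key`, presumably a member that does not exist, while the filter only needs to reject entries carrying the reserved "free" key. Below is a minimal, self-contained sketch of the same pattern; `toy_index` and `count_allowed` are invented names for illustration, not usearch API.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for an index that reserves one key value as "unused".
    struct toy_index {
        std::uint64_t free_key_ = ~std::uint64_t(0); // sentinel marking vacant entries

        std::size_t count_allowed(std::vector<std::uint64_t> const& keys) const {
            // Bind the member to a local reference so the lambda can capture it by
            // name; writing [&free_key_] against the member directly is ill-formed.
            auto& free_key_ = this->free_key_;
            auto allow = [&free_key_](std::uint64_t key) noexcept { return key != free_key_; };

            std::size_t n = 0;
            for (auto key : keys)
                n += allow(key); // count only entries whose key is not the sentinel
            return n;
        }
    };

    int main() {
        toy_index idx;
        std::vector<std::uint64_t> keys{1, 2, idx.free_key_, 3};
        assert(idx.count_allowed(keys) == 3); // the sentinel entry is filtered out
        return 0;
    }

The local alias deliberately reuses the member's name, so the lambda body reads as if the member were captured directly.
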