From 253415cceebda4012e5024216ecea279a650aee8 Mon Sep 17 00:00:00 2001 From: Adib Mohsin Date: Sat, 14 Feb 2026 13:38:16 +0600 Subject: [PATCH 1/2] zig high level bindings --- zig/.gitignore | 6 + zig/LICENSE | 7 + zig/README.md | 499 +++ zig/build.zig | 74 + zig/build.zig.zon | 13 + zig/src/root.zig | 757 ++++ zig/usearch/include/index.hpp | 4548 +++++++++++++++++++++++++ zig/usearch/include/index_dense.hpp | 2273 ++++++++++++ zig/usearch/include/index_plugins.hpp | 3033 +++++++++++++++++ zig/usearch/include/lib.cpp | 507 +++ zig/usearch/include/usearch.h | 487 +++ 11 files changed, 12204 insertions(+) create mode 100644 zig/.gitignore create mode 100644 zig/LICENSE create mode 100644 zig/README.md create mode 100644 zig/build.zig create mode 100644 zig/build.zig.zon create mode 100644 zig/src/root.zig create mode 100644 zig/usearch/include/index.hpp create mode 100644 zig/usearch/include/index_dense.hpp create mode 100644 zig/usearch/include/index_plugins.hpp create mode 100644 zig/usearch/include/lib.cpp create mode 100644 zig/usearch/include/usearch.h diff --git a/zig/.gitignore b/zig/.gitignore new file mode 100644 index 000000000..f038be335 --- /dev/null +++ b/zig/.gitignore @@ -0,0 +1,6 @@ +zig-cache/ +zig-out/ +.zig-cache/ + +.DS_Store +**/.DS_Store diff --git a/zig/LICENSE b/zig/LICENSE new file mode 100644 index 000000000..66e2c93a0 --- /dev/null +++ b/zig/LICENSE @@ -0,0 +1,7 @@ +Copyright 2025 Adib Mohsin + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/zig/README.md b/zig/README.md new file mode 100644 index 000000000..bed6d5155 --- /dev/null +++ b/zig/README.md @@ -0,0 +1,499 @@ +# USearch Zig + +High-performance Zig bindings for [USearch](https://github.com/unum-cloud/usearch) - a smaller, faster, and more scalable vector search library for approximate nearest neighbor (ANN) search. + +## Installation + +### Add as a dependency + +Add USearch Zig to your project using the Zig package manager: + +```bash +zig fetch --save git+https://github.com/pacifio/usearch-zig#main +``` + +### Configure your `build.zig` + +Add the dependency to your executable or library: + +```zig +const usearch_dep = b.dependency("usearch_zig", .{ + .target = target, + .optimize = optimize, +}); + +exe.root_module.addImport("usearch_zig", usearch_dep.module("usearch_zig")); + +exe.root_module.link_libc = true; +exe.root_module.link_libcpp = true; +``` + +## Quick Start + +```zig +const std = @import("std"); +const usearch = @import("usearch_zig"); + +pub fn main() !void { + const allocator = std.heap.page_allocator; + + // Create an index for 3-dimensional vectors + const config = usearch.IndexConfig.default(3); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + // Add some vectors + var vec1 = [_]f32{ 1.0, 2.0, 3.0 }; + var vec2 = [_]f32{ 4.0, 5.0, 6.0 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + + // Search for similar vectors + const results = try index.search(&vec1, 2); + defer allocator.free(results); + + for (results) |result| { + std.debug.print("Key: {}, Distance: {}\n", .{ + result.key, + result.distance, + }); + } +} +``` + +## Usage Examples + +### Basic Index Operations + +Create an index, add vectors, and perform similarity searches: + +```zig +const std = @import("std"); +const usearch = @import("usearch_zig"); + +pub fn basicExample() !void { + const allocator = std.heap.page_allocator; + + // Create an index for 3-dimensional vectors + const config = usearch.IndexConfig.default(3); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + // Check initial size + const size_initial = try index.len(); + std.debug.print("Initial size: {}\n", .{size_initial}); + + // Add vectors with unique keys + var vec1 = [_]f32{ 1.0, 2.0, 3.0 }; + var vec2 = [_]f32{ 4.0, 5.0, 6.0 }; + var vec3 = [_]f32{ 1.1, 2.1, 3.1 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + try index.add(3, &vec3); + + const size_after = try index.len(); + std.debug.print("Size after adding: {}\n", .{size_after}); + + // Search for the 2 nearest neighbors + const results = try index.search(&vec1, 2); + defer allocator.free(results); + + std.debug.print("Found {} results\n", .{results.len}); + for (results) |result| { + std.debug.print(" Key: {}, Distance: {d:.6}\n", .{ + result.key, + result.distance, + }); + } +} +``` + +### Custom Distance Metrics + +Configure the index to use different distance metrics: + +```zig +pub fn customMetricExample() !void { + const allocator = std.heap.page_allocator; + + // Create an index with L2 squared distance + var config = usearch.IndexConfig.default(3); + config.metric = .l2sq; // Options: .cosine, .l2sq, .inner_product, etc. + + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + var vec1 = [_]f32{ 1.0, 0.0, 0.0 }; + var vec2 = [_]f32{ 0.0, 1.0, 0.0 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + + const results = try index.search(&vec1, 1); + defer allocator.free(results); + + std.debug.print("Nearest neighbor: key={}\n", .{results[0].key}); +} +``` + +### Remove and Contains Operations + +Manage vectors in the index dynamically: + +```zig +pub fn removeAndContainsExample() !void { + const allocator = std.heap.page_allocator; + + const config = usearch.IndexConfig.default(3); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + var vec = [_]f32{ 1.0, 2.0, 3.0 }; + try index.add(42, &vec); + + // Check if vector exists + const exists_before = try index.contains(42); + std.debug.print("Vector 42 exists: {}\n", .{exists_before}); + + // Remove the vector + try index.remove(42); + + const exists_after = try index.contains(42); + std.debug.print("Vector 42 exists after removal: {}\n", .{exists_after}); +} +``` + +### Direct Distance Calculation + +Calculate distances between vectors without building an index: + +```zig +pub fn distanceExample() !void { + var vec1 = [_]f32{ 1.0, 0.0, 0.0 }; + var vec2 = [_]f32{ 0.0, 1.0, 0.0 }; + + const dist = try usearch.distance(&vec1, &vec2, 3, .l2sq); + std.debug.print("L2 squared distance: {}\n", .{dist}); +} +``` + +### Quantized Vectors (int8) + +Use int8 quantization for memory-efficient storage: + +```zig +pub fn int8VectorsExample() !void { + const allocator = std.heap.page_allocator; + + var config = usearch.IndexConfig.default(4); + config.quantization = .i8; // Options: .f32, .f16, .bf16, .f64, .i8, .b1 + + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + var vec1 = [_]i8{ 1, 2, 3, 4 }; + var vec2 = [_]i8{ 5, 6, 7, 8 }; + + try index.addI8(1, &vec1); + try index.addI8(2, &vec2); + + const results = try index.searchI8(&vec1, 1); + defer allocator.free(results); + + std.debug.print("Nearest neighbor: key={}\n", .{results[0].key}); +} +``` + +### Reserve Capacity + +Pre-allocate space for better performance: + +```zig +pub fn reserveCapacityExample() !void { + const allocator = std.heap.page_allocator; + + const config = usearch.IndexConfig.default(3); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + // Reserve space for 5000 vectors upfront + try index.reserve(5000); + + const cap = try index.capacity(); + std.debug.print("Reserved capacity: {}\n", .{cap}); +} +``` + +### Query Index Properties + +Inspect index configuration and statistics: + +```zig +pub fn queryPropertiesExample() !void { + const allocator = std.heap.page_allocator; + + const config = usearch.IndexConfig.default(128); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + const dims = try index.dimensions(); + std.debug.print("Dimensions: {}\n", .{dims}); + + const cap = try index.capacity(); + std.debug.print("Capacity: {}\n", .{cap}); + + const mem = try index.memoryUsage(); + std.debug.print("Memory usage: {} bytes\n", .{mem}); +} +``` + +### Working with Larger Datasets + +Efficiently handle large collections of vectors: + +```zig +pub fn largerDatasetExample() !void { + const allocator = std.heap.page_allocator; + + var config = usearch.IndexConfig.default(8); + config.initial_capacity = 100; + + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + // Add 50 vectors + var i: usize = 0; + while (i < 50) : (i += 1) { + var vec = [_]f32{ + @as(f32, @floatFromInt(i)), + @as(f32, @floatFromInt(i + 1)), + @as(f32, @floatFromInt(i + 2)), + @as(f32, @floatFromInt(i + 3)), + @as(f32, @floatFromInt(i + 4)), + @as(f32, @floatFromInt(i + 5)), + @as(f32, @floatFromInt(i + 6)), + @as(f32, @floatFromInt(i + 7)), + }; + try index.add(i, &vec); + } + + const final_size = try index.len(); + std.debug.print("Final index size: {}\n", .{final_size}); + + // Search for similar vectors + var query = [_]f32{ 10, 11, 12, 13, 14, 15, 16, 17 }; + const results = try index.search(&query, 5); + defer allocator.free(results); + + std.debug.print("Top 5 results:\n", .{}); + for (results) |result| { + std.debug.print(" Key: {}, Distance: {d:.6}\n", .{ + result.key, + result.distance, + }); + } +} +``` + +### Persistence: Save and Load + +Save indices to disk and load them later: + +```zig +pub fn persistenceExample() !void { + const allocator = std.heap.page_allocator; + + // Create and populate an index + var config = usearch.IndexConfig.default(3); + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + var vec1 = [_]f32{ 1.0, 2.0, 3.0 }; + var vec2 = [_]f32{ 4.0, 5.0, 6.0 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + + // Save to disk + try index.save("my_index.usearch"); + std.debug.print("Index saved\n", .{}); + + // Load from disk + var loaded_index = try usearch.Index.init(allocator, config); + defer loaded_index.deinit(); + + try loaded_index.load("my_index.usearch"); + + const size = try loaded_index.len(); + std.debug.print("Loaded index with {} vectors\n", .{size}); +} +``` + +### Advanced Configuration + +Fine-tune performance parameters: + +```zig +pub fn advancedConfigExample() !void { + const allocator = std.heap.page_allocator; + + var config = usearch.IndexConfig{ + .dimensions = 128, + .metric = .cosine, + .quantization = .f32, + .connectivity = 16, // Higher = more accurate but slower + .expansion_add = 128, // Controls indexing quality + .expansion_search = 64, // Controls search quality + .multi = false, // Allow multiple vectors per key + .initial_capacity = 10000, + }; + + var index = try usearch.Index.init(allocator, config); + defer index.deinit(); + + // Dynamically adjust parameters + try index.setExpansionSearch(128); + try index.setThreadsSearch(4); + + std.debug.print("Advanced index configured\n", .{}); +} +``` + +## API Reference + +### Types + +#### `Key` + +```zig +pub const Key = u64; +``` + +Unique identifier for vectors in the index. + +#### `Metric` + +```zig +pub const Metric = enum(u8) { + inner_product, + cosine, + l2sq, + haversine, + divergence, + pearson, + hamming, + tanimoto, + sorensen, +}; +``` + +Distance metrics for comparing vectors. + +#### `Quantization` + +```zig +pub const Quantization = enum(u8) { + f32, + bf16, + f16, + f64, + i8, + b1, +}; +``` + +Scalar quantization types for vector storage. + +#### `IndexConfig` + +```zig +pub const IndexConfig = struct { + quantization: Quantization = .f32, + metric: Metric = .cosine, + dimensions: usize, + connectivity: usize = 0, + expansion_add: usize = 0, + expansion_search: usize = 0, + multi: bool = false, + initial_capacity: usize = 1000, +}; +``` + +Configuration options for creating an index. + +#### `SearchResult` + +```zig +pub const SearchResult = struct { + key: Key, + distance: f32, +}; +``` + +Result containing a key and its distance from the query. + +### Index Methods + +#### Core Operations + +- **`init(allocator, config) !Index`** - Create a new index +- **`deinit()`** - Free index resources +- **`add(key, vector) !void`** - Add a float32 vector +- **`addI8(key, vector) !void`** - Add an int8 vector +- **`search(query, k) ![]SearchResult`** - Find k nearest neighbors (returns owned slice) +- **`searchI8(query, k) ![]SearchResult`** - Search using int8 query +- **`remove(key) !void`** - Remove a vector by key +- **`contains(key) !bool`** - Check if a key exists +- **`get(key, max_count) !?[]f32`** - Retrieve vector by key (returns owned slice or null) +- **`rename(from, to) !void`** - Rename a vector key + +#### Queries + +- **`len() !usize`** - Get number of vectors in index +- **`capacity() !usize`** - Get current capacity +- **`dimensions() !usize`** - Get vector dimensionality +- **`memoryUsage() !usize`** - Get memory usage in bytes +- **`reserve(capacity) !void`** - Reserve space for vectors + +#### Persistence + +- **`save(path) !void`** - Save index to file +- **`load(path) !void`** - Load index from file +- **`view(path) !void`** - Memory-map index from file (zero-copy) + +#### Configuration + +- **`setExpansionAdd(expansion) !void`** - Set expansion factor for adding +- **`setExpansionSearch(expansion) !void`** - Set expansion factor for search +- **`setThreadsAdd(threads) !void`** - Set number of threads for indexing +- **`setThreadsSearch(threads) !void`** - Set number of threads for search + +### Utility Functions + +- **`distance(vec1, vec2, dimensions, metric) !f32`** - Compute distance between float32 vectors +- **`distanceI8(vec1, vec2, dimensions, metric) !f32`** - Compute distance between int8 vectors +- **`loadMetadata(allocator, path) !IndexConfig`** - Load metadata from a saved index file + +## Performance Tips + +1. **Reserve capacity upfront** - Use `reserve()` or set `initial_capacity` to avoid reallocations +2. **Choose the right metric** - Cosine similarity is good for normalized vectors, L2 for geometric distances +3. **Tune connectivity** - Higher connectivity (16-32) improves accuracy but increases memory usage +4. **Use quantization** - int8 or f16 can significantly reduce memory usage with minimal accuracy loss +5. **Parallel operations** - Use `setThreadsAdd()` and `setThreadsSearch()` for multi-threaded performance +6. **Adjust expansion factors** - Higher values improve accuracy but slow down operations + +## Building from Source + +```bash +git clone https://github.com/pacifio/usearch-zig +cd usearch-zig +zig build test +``` + +## Requirements + +- Zig 0.15.1 or later +- C++ compiler with C++17 support diff --git a/zig/build.zig b/zig/build.zig new file mode 100644 index 000000000..c5ec07939 --- /dev/null +++ b/zig/build.zig @@ -0,0 +1,74 @@ +const std = @import("std"); + +pub fn build(b: *std.Build) void { + const target = b.standardTargetOptions(.{}); + const optimize = b.standardOptimizeOption(.{}); + + const usearch_mod = b.addModule("usearch_zig", .{ + .root_source_file = b.path("src/root.zig"), + .target = target, + .optimize = optimize, + }); + + const usearch_cpp_mod = b.createModule(.{ + .target = target, + .optimize = optimize, + }); + + usearch_cpp_mod.addCSourceFile(.{ + .file = b.path("usearch/include/lib.cpp"), + .flags = &.{ + "-std=c++17", + "-fno-exceptions", + "-fno-rtti", + }, + }); + + usearch_cpp_mod.addIncludePath(b.path("usearch/include")); + usearch_cpp_mod.link_libcpp = true; + + const usearch_lib = b.addLibrary(.{ + .name = "usearch", + .root_module = usearch_cpp_mod, + .linkage = .static, + }); + + const exe = b.addExecutable(.{ + .name = "usearch_zig", + .root_module = b.createModule(.{ + .root_source_file = b.path("src/main.zig"), + .target = target, + .optimize = optimize, + .imports = &.{ + .{ .name = "usearch_zig", .module = usearch_mod }, + }, + }), + }); + + exe.root_module.link_libc = true; + exe.root_module.link_libcpp = true; + exe.root_module.linkLibrary(usearch_lib); + exe.root_module.addIncludePath(b.path("usearch/include")); + + b.installArtifact(exe); + + const run_step = b.step("run", "Run the usearch_zig executable"); + const run_cmd = b.addRunArtifact(exe); + run_step.dependOn(&run_cmd.step); + run_cmd.step.dependOn(b.getInstallStep()); + + if (b.args) |args| run_cmd.addArgs(args); + + const mod_tests = b.addTest(.{ + .root_module = usearch_mod, + }); + + mod_tests.root_module.addIncludePath(b.path("usearch/include")); + mod_tests.root_module.link_libcpp = true; + mod_tests.root_module.linkLibrary(usearch_lib); + + const run_mod_tests = b.addRunArtifact(mod_tests); + + const test_step = b.step("test", "Run tests for usearch_zig"); + test_step.dependOn(&run_mod_tests.step); +} diff --git a/zig/build.zig.zon b/zig/build.zig.zon new file mode 100644 index 000000000..24b7e8d8d --- /dev/null +++ b/zig/build.zig.zon @@ -0,0 +1,13 @@ +.{ + .name = .usearch_zig, + .version = "0.0.1", + .fingerprint = 0x20d43150bf11d52, + .minimum_zig_version = "0.15.1", + .dependencies = .{}, + .paths = .{ + "build.zig", + "build.zig.zon", + "src", + "usearch", + }, +} diff --git a/zig/src/root.zig b/zig/src/root.zig new file mode 100644 index 000000000..06b76094d --- /dev/null +++ b/zig/src/root.zig @@ -0,0 +1,757 @@ +const std = @import("std"); +const c = @cImport({ + @cInclude("usearch.h"); +}); + +/// Unique identifier for vectors in the index +pub const Key = u64; + +/// Distance metric for comparing vectors +pub const Metric = enum(u8) { + inner_product = 0, + cosine = 1, + l2sq = 2, + haversine = 3, + divergence = 4, + pearson = 5, + hamming = 6, + tanimoto = 7, + sorensen = 8, + + pub fn toCValue(self: Metric) c.usearch_metric_kind_t { + return switch (self) { + .l2sq => c.usearch_metric_l2sq_k, + .inner_product => c.usearch_metric_ip_k, + .cosine => c.usearch_metric_cos_k, + .haversine => c.usearch_metric_haversine_k, + .divergence => c.usearch_metric_divergence_k, + .pearson => c.usearch_metric_pearson_k, + .hamming => c.usearch_metric_hamming_k, + .tanimoto => c.usearch_metric_tanimoto_k, + .sorensen => c.usearch_metric_sorensen_k, + }; + } + + pub fn toString(self: Metric) []const u8 { + return switch (self) { + .l2sq => "l2sq", + .inner_product => "ip", + .cosine => "cos", + .haversine => "haversine", + .divergence => "divergence", + .pearson => "pearson", + .hamming => "hamming", + .tanimoto => "tanimoto", + .sorensen => "sorensen", + }; + } +}; + +/// Scalar quantization type for vector storage +pub const Quantization = enum(u8) { + f32 = 0, + bf16 = 1, + f16 = 2, + f64 = 3, + i8 = 4, + b1 = 5, + + pub fn toCValue(self: Quantization) c.usearch_scalar_kind_t { + return switch (self) { + .f16 => c.usearch_scalar_f16_k, + .f32 => c.usearch_scalar_f32_k, + .f64 => c.usearch_scalar_f64_k, + .i8 => c.usearch_scalar_i8_k, + .b1 => c.usearch_scalar_b1_k, + .bf16 => c.usearch_scalar_bf16_k, + }; + } + + pub fn toString(self: Quantization) []const u8 { + return switch (self) { + .bf16 => "BF16", + .f16 => "F16", + .f32 => "F32", + .f64 => "F64", + .i8 => "I8", + .b1 => "B1", + }; + } +}; + +/// Configuration for creating a new USearch index +pub const IndexConfig = struct { + quantization: Quantization = .f32, + metric: Metric = .cosine, + dimensions: usize, + connectivity: usize = 0, + expansion_add: usize = 0, + expansion_search: usize = 0, + multi: bool = false, + initial_capacity: usize = 1000, + + /// Create default configuration for given dimensions + pub fn default(dimensions: usize) IndexConfig { + return .{ + .dimensions = dimensions, + .metric = .cosine, + .quantization = .f32, + }; + } +}; + +/// Errors that can occur during USearch operations +pub const Error = error{ + IndexUninitialized, + DimensionMismatch, + EmptyVector, + BufferTooSmall, + InvalidPath, + UsearchError, + OutOfMemory, +}; + +/// Search result containing a key and its distance +pub const SearchResult = struct { + key: Key, + distance: f32, +}; + +/// Main USearch index for approximate nearest neighbor search +pub const Index = struct { + handle: c.usearch_index_t, + config: IndexConfig, + allocator: std.mem.Allocator, + + /// Create a new USearch index with the given configuration + pub fn init(allocator: std.mem.Allocator, config: IndexConfig) Error!Index { + if (config.dimensions == 0) { + return Error.DimensionMismatch; + } + + var options = std.mem.zeroes(c.usearch_init_options_t); + options.metric_kind = config.metric.toCValue(); + options.quantization = config.quantization.toCValue(); + options.dimensions = config.dimensions; + options.connectivity = config.connectivity; + options.expansion_add = config.expansion_add; + options.expansion_search = config.expansion_search; + options.multi = config.multi; + + var error_msg: c.usearch_error_t = null; + const handle = c.usearch_init(&options, &error_msg); + + if (error_msg != null) { + return Error.UsearchError; + } + + if (handle == null) { + return Error.UsearchError; + } + + const index = Index{ + .handle = handle, + .config = config, + .allocator = allocator, + }; + + var reserve_err: c.usearch_error_t = null; + c.usearch_reserve(index.handle, config.initial_capacity, &reserve_err); + + if (reserve_err != null) { + c.usearch_free(index.handle, &reserve_err); + return Error.UsearchError; + } + + return index; + } + + /// Free resources associated with the index + pub fn deinit(self: *Index) void { + if (self.handle) |handle| { + var error_msg: c.usearch_error_t = null; + c.usearch_free(handle, &error_msg); + if (error_msg != null) { + std.debug.print("Warning: error during usearch_free: {s}\n", .{error_msg}); + } + } + self.handle = null; + } + + /// Get the number of vectors in the index + pub fn len(self: *const Index) Error!usize { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + const size = c.usearch_size(self.handle, &error_msg); + + if (error_msg != null) return Error.UsearchError; + return size; + } + + /// Get the capacity of the index + pub fn capacity(self: *const Index) Error!usize { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + const cap = c.usearch_capacity(self.handle, &error_msg); + + if (error_msg != null) return Error.UsearchError; + return cap; + } + + /// Get the dimensions of vectors in the index + pub fn dimensions(self: *const Index) Error!usize { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + const dims = c.usearch_dimensions(self.handle, &error_msg); + + if (error_msg != null) return Error.UsearchError; + return dims; + } + + /// Get memory usage in bytes + pub fn memoryUsage(self: *const Index) Error!usize { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + const usage = c.usearch_memory_usage(self.handle, &error_msg); + + if (error_msg != null) return Error.UsearchError; + return usage; + } + + /// Reserve capacity for a number of vectors + pub fn reserve(self: *Index, cap: usize) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + c.usearch_reserve(self.handle, cap, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Add a float32 vector to the index + pub fn add(self: *Index, key: Key, vector: []const f32) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + if (vector.len == 0) return Error.EmptyVector; + if (vector.len != self.config.dimensions) return Error.DimensionMismatch; + + const vec_ptr: ?*const anyopaque = @ptrCast(vector.ptr); + + var error_msg: c.usearch_error_t = null; + c.usearch_add( + self.handle, + @as(c.usearch_key_t, key), + vec_ptr, + c.usearch_scalar_f32_k, + @as([*c]c.usearch_error_t, @ptrCast(&error_msg)), + ); + + if (error_msg != null) return Error.UsearchError; + } + + /// Add an int8 vector to the index + pub fn addI8(self: *Index, key: Key, vector: []const i8) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + if (vector.len == 0) return Error.EmptyVector; + if (vector.len != self.config.dimensions) return Error.DimensionMismatch; + + var error_msg: c.usearch_error_t = null; + c.usearch_add( + self.handle, + key, + @as(?*const anyopaque, @ptrCast(vector.ptr)), + c.usearch_scalar_i8_k, + &error_msg, + ); + + if (error_msg != null) return Error.UsearchError; + } + + /// Remove a vector by key + pub fn remove(self: *Index, key: Key) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + _ = c.usearch_remove(self.handle, key, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Check if a key exists in the index + pub fn contains(self: *const Index, key: Key) Error!bool { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + const found = c.usearch_contains(self.handle, key, &error_msg); + + if (error_msg != null) return Error.UsearchError; + return found; + } + + /// Get a vector by key. Returns owned slice, caller must free. + pub fn get(self: *const Index, key: Key, max_count: usize) Error!?[]f32 { + if (self.handle == null) return Error.IndexUninitialized; + if (max_count == 0) return null; + + const buffer = try self.allocator.alloc(f32, self.config.dimensions * max_count); + errdefer self.allocator.free(buffer); + + var error_msg: c.usearch_error_t = null; + const found = c.usearch_get( + self.handle, + key, + max_count, + @as(?*anyopaque, @ptrCast(buffer.ptr)), + c.usearch_scalar_f32_k, + &error_msg, + ); + + if (error_msg != null) { + self.allocator.free(buffer); + return Error.UsearchError; + } + + if (found == 0) { + self.allocator.free(buffer); + return null; + } + + return buffer; + } + + /// Rename a vector from one key to another + pub fn rename(self: *Index, from: Key, to: Key) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + _ = c.usearch_rename(self.handle, from, to, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Search for nearest neighbors. Returns owned slice, caller must free. + pub fn search( + self: *const Index, + query: []const f32, + limit: usize, + ) Error![]SearchResult { + if (self.handle == null) return Error.IndexUninitialized; + if (query.len == 0) return Error.EmptyVector; + if (query.len != self.config.dimensions) return Error.DimensionMismatch; + if (limit == 0) return &[_]SearchResult{}; + + const keys = try self.allocator.alloc(Key, limit); + defer self.allocator.free(keys); + + const distances = try self.allocator.alloc(f32, limit); + defer self.allocator.free(distances); + + var error_msg: c.usearch_error_t = null; + const result_count = c.usearch_search( + self.handle, + @as(?*const anyopaque, @ptrCast(query.ptr)), + c.usearch_scalar_f32_k, + limit, + @as([*c]c.usearch_key_t, @ptrCast(keys.ptr)), + @as([*c]c.usearch_distance_t, @ptrCast(distances.ptr)), + &error_msg, + ); + + if (error_msg != null) return Error.UsearchError; + + const results = try self.allocator.alloc(SearchResult, result_count); + for (0..result_count) |i| { + results[i] = .{ + .key = keys[i], + .distance = distances[i], + }; + } + + return results; + } + + /// Search using int8 query vector. Returns owned slice, caller must free. + pub fn searchI8( + self: *const Index, + query: []const i8, + limit: usize, + ) Error![]SearchResult { + if (self.handle == null) return Error.IndexUninitialized; + if (query.len == 0) return Error.EmptyVector; + if (query.len != self.config.dimensions) return Error.DimensionMismatch; + if (limit == 0) return &[_]SearchResult{}; + + const keys = try self.allocator.alloc(Key, limit); + defer self.allocator.free(keys); + + const distances = try self.allocator.alloc(f32, limit); + defer self.allocator.free(distances); + + var error_msg: c.usearch_error_t = null; + const result_count = c.usearch_search( + self.handle, + @as(?*const anyopaque, @ptrCast(query.ptr)), + c.usearch_scalar_i8_k, + limit, + @as([*c]c.usearch_key_t, @ptrCast(keys.ptr)), + @as([*c]c.usearch_distance_t, @ptrCast(distances.ptr)), + &error_msg, + ); + + if (error_msg != null) return Error.UsearchError; + + const results = try self.allocator.alloc(SearchResult, result_count); + for (0..result_count) |i| { + results[i] = .{ + .key = keys[i], + .distance = distances[i], + }; + } + + return results; + } + + /// Save index to file + pub fn save(self: *const Index, path: []const u8) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + if (path.len == 0) return Error.InvalidPath; + + const c_path = try self.allocator.dupeZ(u8, path); + defer self.allocator.free(c_path); + + var error_msg: c.usearch_error_t = null; + c.usearch_save(self.handle, c_path.ptr, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Load index from file + pub fn load(self: *Index, path: []const u8) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + if (path.len == 0) return Error.InvalidPath; + + const c_path = try self.allocator.dupeZ(u8, path); + defer self.allocator.free(c_path); + + var error_msg: c.usearch_error_t = null; + c.usearch_load(self.handle, c_path.ptr, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// View index from file without loading into memory + pub fn view(self: *Index, path: []const u8) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + if (path.len == 0) return Error.InvalidPath; + + const c_path = try self.allocator.dupeZ(u8, path); + defer self.allocator.free(c_path); + + var error_msg: c.usearch_error_t = null; + c.usearch_view(self.handle, c_path.ptr, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Change the expansion factor for adding vectors + pub fn setExpansionAdd(self: *Index, expansion: usize) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + c.usearch_change_expansion_add(self.handle, expansion, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Change the expansion factor for search + pub fn setExpansionSearch(self: *Index, expansion: usize) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + c.usearch_change_expansion_search(self.handle, expansion, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Set number of threads for adding vectors + pub fn setThreadsAdd(self: *Index, threads: usize) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + c.usearch_change_threads_add(self.handle, threads, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } + + /// Set number of threads for search + pub fn setThreadsSearch(self: *Index, threads: usize) Error!void { + if (self.handle == null) return Error.IndexUninitialized; + + var error_msg: c.usearch_error_t = null; + c.usearch_change_threads_search(self.handle, threads, &error_msg); + + if (error_msg != null) return Error.UsearchError; + } +}; + +/// Compute distance between two float32 vectors +pub fn distance( + vec1: []const f32, + vec2: []const f32, + dimensions: usize, + metric: Metric, +) Error!f32 { + if (vec1.len == 0 or vec2.len == 0) return Error.EmptyVector; + if (vec1.len < dimensions or vec2.len < dimensions) return Error.DimensionMismatch; + + var error_msg: c.usearch_error_t = null; + const dist = c.usearch_distance( + @as(?*const anyopaque, @ptrCast(vec1.ptr)), + @as(?*const anyopaque, @ptrCast(vec2.ptr)), + c.usearch_scalar_f32_k, + dimensions, + metric.toCValue(), + &error_msg, + ); + + if (error_msg != null) return Error.UsearchError; + return dist; +} + +/// Compute distance between two int8 vectors +pub fn distanceI8( + vec1: []const i8, + vec2: []const i8, + dimensions: usize, + metric: Metric, +) Error!f32 { + if (vec1.len == 0 or vec2.len == 0) return Error.EmptyVector; + if (vec1.len < dimensions or vec2.len < dimensions) return Error.DimensionMismatch; + + var error_msg: c.usearch_error_t = null; + const dist = c.usearch_distance( + @as(?*const anyopaque, @ptrCast(vec1.ptr)), + @as(?*const anyopaque, @ptrCast(vec2.ptr)), + c.usearch_scalar_i8_k, + dimensions, + metric.toCValue(), + &error_msg, + ); + + if (error_msg != null) return Error.UsearchError; + return dist; +} + +/// Load metadata from a saved index file +pub fn loadMetadata(allocator: std.mem.Allocator, path: []const u8) Error!IndexConfig { + if (path.len == 0) return Error.InvalidPath; + + const c_path = try allocator.dupeZ(u8, path); + defer allocator.free(c_path); + + var options = std.mem.zeroes(c.usearch_init_options_t); + var error_msg: c.usearch_error_t = null; + + c.usearch_metadata(c_path.ptr, &options, &error_msg); + if (error_msg != null) return Error.UsearchError; + + var config = IndexConfig{ + .dimensions = options.dimensions, + .connectivity = options.connectivity, + .expansion_add = options.expansion_add, + .expansion_search = options.expansion_search, + .multi = options.multi, + .metric = .cosine, + .quantization = .f32, + }; + + config.metric = switch (options.metric_kind) { + c.usearch_metric_l2sq_k => .l2sq, + c.usearch_metric_ip_k => .inner_product, + c.usearch_metric_cos_k => .cosine, + c.usearch_metric_haversine_k => .haversine, + c.usearch_metric_divergence_k => .divergence, + c.usearch_metric_pearson_k => .pearson, + c.usearch_metric_hamming_k => .hamming, + c.usearch_metric_tanimoto_k => .tanimoto, + c.usearch_metric_sorensen_k => .sorensen, + else => .cosine, + }; + + config.quantization = switch (options.quantization) { + c.usearch_scalar_f16_k => .f16, + c.usearch_scalar_f32_k => .f32, + c.usearch_scalar_f64_k => .f64, + c.usearch_scalar_i8_k => .i8, + c.usearch_scalar_b1_k => .b1, + c.usearch_scalar_bf16_k => .bf16, + else => .f32, + }; + + return config; +} + +test "basic index operations" { + const allocator = std.testing.allocator; + + const config = IndexConfig.default(3); + var index = try Index.init(allocator, config); + defer index.deinit(); + + const size_initial = try index.len(); + try std.testing.expectEqual(@as(usize, 0), size_initial); + + var vec1 = [_]f32{ 1.0, 2.0, 3.0 }; + var vec2 = [_]f32{ 4.0, 5.0, 6.0 }; + var vec3 = [_]f32{ 1.1, 2.1, 3.1 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + try index.add(3, &vec3); + + const size_after = try index.len(); + try std.testing.expectEqual(@as(usize, 3), size_after); + + const results = try index.search(&vec1, 2); + defer allocator.free(results); + + try std.testing.expect(results.len > 0); + try std.testing.expectEqual(@as(Key, 1), results[0].key); + try std.testing.expect(results[0].distance < 0.01); + + if (results.len > 1) { + try std.testing.expectEqual(@as(Key, 3), results[1].key); + } +} + +test "index with custom metric" { + const allocator = std.testing.allocator; + + var config = IndexConfig.default(3); + config.metric = .l2sq; + + var index = try Index.init(allocator, config); + defer index.deinit(); + + var vec1 = [_]f32{ 1.0, 0.0, 0.0 }; + var vec2 = [_]f32{ 0.0, 1.0, 0.0 }; + + try index.add(1, &vec1); + try index.add(2, &vec2); + + const results = try index.search(&vec1, 1); + defer allocator.free(results); + + try std.testing.expectEqual(@as(Key, 1), results[0].key); +} + +test "remove and contains" { + const allocator = std.testing.allocator; + + const config = IndexConfig.default(3); + var index = try Index.init(allocator, config); + defer index.deinit(); + + var vec = [_]f32{ 1.0, 2.0, 3.0 }; + try index.add(42, &vec); + + const exists_before = try index.contains(42); + try std.testing.expect(exists_before); + + try index.remove(42); + + const exists_after = try index.contains(42); + try std.testing.expect(!exists_after); +} + +test "distance calculation" { + var vec1 = [_]f32{ 1.0, 0.0, 0.0 }; + var vec2 = [_]f32{ 0.0, 1.0, 0.0 }; + + const dist = try distance(&vec1, &vec2, 3, .l2sq); + try std.testing.expectEqual(@as(f32, 2.0), dist); +} + +test "int8 vectors" { + const allocator = std.testing.allocator; + + var config = IndexConfig.default(4); + config.quantization = .i8; + var index = try Index.init(allocator, config); + defer index.deinit(); + + var vec1 = [_]i8{ 1, 2, 3, 4 }; + var vec2 = [_]i8{ 5, 6, 7, 8 }; + + try index.addI8(1, &vec1); + try index.addI8(2, &vec2); + + const results = try index.searchI8(&vec1, 1); + defer allocator.free(results); + + try std.testing.expectEqual(@as(Key, 1), results[0].key); +} + +test "reserve capacity" { + const allocator = std.testing.allocator; + + const config = IndexConfig.default(3); + var index = try Index.init(allocator, config); + defer index.deinit(); + + try index.reserve(5000); + + const cap = try index.capacity(); + try std.testing.expect(cap >= 5000); +} + +test "dimensions and capacity queries" { + const allocator = std.testing.allocator; + + const config = IndexConfig.default(128); + var index = try Index.init(allocator, config); + defer index.deinit(); + + const dims = try index.dimensions(); + try std.testing.expectEqual(@as(usize, 128), dims); + + const cap = try index.capacity(); + try std.testing.expect(cap > 0); +} + +test "index with larger dataset" { + const allocator = std.testing.allocator; + + var config = IndexConfig.default(8); + config.initial_capacity = 100; + var index = try Index.init(allocator, config); + defer index.deinit(); + + var i: usize = 0; + while (i < 50) : (i += 1) { + var vec = [_]f32{ + @as(f32, @floatFromInt(i)), + @as(f32, @floatFromInt(i + 1)), + @as(f32, @floatFromInt(i + 2)), + @as(f32, @floatFromInt(i + 3)), + @as(f32, @floatFromInt(i + 4)), + @as(f32, @floatFromInt(i + 5)), + @as(f32, @floatFromInt(i + 6)), + @as(f32, @floatFromInt(i + 7)), + }; + try index.add(i, &vec); + } + + const final_size = try index.len(); + try std.testing.expectEqual(@as(usize, 50), final_size); + + var query = [_]f32{ 10, 11, 12, 13, 14, 15, 16, 17 }; + const results = try index.search(&query, 5); + defer allocator.free(results); + + try std.testing.expect(results.len > 0); + try std.testing.expectEqual(@as(Key, 10), results[0].key); +} diff --git a/zig/usearch/include/index.hpp b/zig/usearch/include/index.hpp new file mode 100644 index 000000000..73c70755b --- /dev/null +++ b/zig/usearch/include/index.hpp @@ -0,0 +1,4548 @@ +/** + * @file index.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine. + * @date April 26, 2023 + */ +#ifndef UNUM_USEARCH_HPP +#define UNUM_USEARCH_HPP + +#define USEARCH_VERSION_MAJOR 2 +#define USEARCH_VERSION_MINOR 21 +#define USEARCH_VERSION_PATCH 1 + +// Inferring C++ version +// https://stackoverflow.com/a/61552074 +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 201703L) || __cplusplus >= 201703L) +#define USEARCH_DEFINED_CPP17 +#endif +#if ((defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) || __cplusplus >= 202002L) +#define USEARCH_DEFINED_CPP20 +#endif + +// Inferring target OS: Windows, MacOS, or Linux +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__NT__) +#define USEARCH_DEFINED_WINDOWS +#elif defined(__APPLE__) && defined(__MACH__) +#define USEARCH_DEFINED_APPLE +#elif defined(__linux__) +#define USEARCH_DEFINED_LINUX +#if defined(__ANDROID_API__) +#define USEARCH_DEFINED_ANDROID +#endif +#endif + +// Inferring the compiler: Clang vs GCC +#if defined(__clang__) +#define USEARCH_DEFINED_CLANG +#elif defined(__GNUC__) +#define USEARCH_DEFINED_GCC +#endif + +// The `#pragma region` and `#pragma endregion` are not supported by GCC 12 and older. +// But they are supported by GCC 13, all recent Clang versions, and MSVC. +#if defined(__GNUC__) && ((__GNUC__ > 13) || (__GNUC__ == 13 && __GNUC_MINOR__ >= 0)) +#define USEARCH_USE_PRAGMA_REGION +#elif defined(__clang__) || defined(_MSC_VER) +#define USEARCH_USE_PRAGMA_REGION +#endif + +// Inferring hardware architecture: x86 vs Arm +#if defined(__x86_64__) +#define USEARCH_DEFINED_X86 +#elif defined(__aarch64__) +#define USEARCH_DEFINED_ARM +#endif + +// Inferring hardware bitness: 32 vs 64 +// Using compiler predefined macros for is technically safer than including `` and +// using the commonly advised `UINTPTR_MAX` trick, as that constant is optional in standard C/C++. +// https://stackoverflow.com/a/5273354 +// https://en.cppreference.com/w/cpp/types/integer.html +#if defined(_WIN64) || defined(__LP64__) || defined(__x86_64__) || defined(__aarch64__) || defined(__powerpc64__) +#define USEARCH_64BIT_ENV +#else +#define USEARCH_32BIT_ENV +#endif + +#if !defined(USEARCH_USE_OPENMP) +#define USEARCH_USE_OPENMP 0 +#endif + +// OS-specific includes +#if defined(USEARCH_DEFINED_WINDOWS) +#define _USE_MATH_DEFINES +#define NOMINMAX +#include +#include // `fstat` for file size +#undef NOMINMAX +#undef _USE_MATH_DEFINES +#else +#include // `fallocate` +#include // `posix_memalign` +#include // `mmap` +#include // `fstat` for file size +#include // `open`, `close` +#endif + +// STL includes +#include // `std::sort_heap` +#include // `std::atomic` +#include // `std::bitset` +#include // `CHAR_BIT` +#include // `std::sqrt` +#include // `std::memset` +#include // `std::reverse_iterator` +#include // `std::unique_lock` - replacement candidate +#include // `std::default_random_engine` - replacement candidate +#include // `std::runtime_exception` +#include // `std::thread` +#include // `std::pair` + +// Helper macros for concatenation and stringification +#define usearch_concat_helper_m(a, b) a##b +#define usearch_concat_m(a, b) usearch_concat_helper_m(a, b) +#define usearch_stringify_helper_m(x) #x +#define usearch_stringify_m(x) usearch_stringify_helper_m(x) + +// Prefetching +#if defined(USEARCH_DEFINED_GCC) +// https://gcc.gnu.org/onlinedocs/gcc/Other-Builtins.html +// Zero means we are only going to read from that memory. +// Three means high temporal locality and suggests to keep +// the data in all layers of cache. +#define usearch_prefetch_m(ptr) __builtin_prefetch((void*)(ptr), 0, 3) +#elif defined(USEARCH_DEFINED_X86) +#define usearch_prefetch_m(ptr) _mm_prefetch((void*)(ptr), _MM_HINT_T0) +#else +#define usearch_prefetch_m(ptr) +#endif + +// Function profiling +#if defined(usearch_defined_x86) +#define usearch_profiled_m __attribute__((noinline)) +#define usearch_profile_name_m(name) \ + __asm__ volatile(".globl " usearch_stringify_m(usearch_concat_m(name, __COUNTER__)) "\n" usearch_stringify_m( \ + usearch_concat_m(name, __COUNTER__)) ":") +#elif defined(usearch_defined_arm) +#define usearch_profiled_m __attribute__((noinline)) +#define usearch_profile_name_m(name) \ + __asm__ volatile(".global " usearch_stringify_m(usearch_concat_m(name, __COUNTER__)) "\n" usearch_stringify_m( \ + usearch_concat_m(name, __COUNTER__)) ":") +#else +#define usearch_profiled_m +#define usearch_profile_name_m(name) +#endif + +// Alignment +#if defined(USEARCH_DEFINED_WINDOWS) +#define usearch_pack_m +#define usearch_align_m __declspec(align(64)) +#else +#define usearch_pack_m __attribute__((packed)) +#define usearch_align_m __attribute__((aligned(64))) +#endif + +// Debugging +#if defined(NDEBUG) +#define usearch_assert_m(must_be_true, message) +#define usearch_noexcept_m noexcept +#else +#define usearch_assert_m(must_be_true, message) \ + if (!(must_be_true)) { \ + usearch_raise_runtime_error(message); \ + } +#define usearch_noexcept_m +#endif + +extern "C" { +/// @brief Helper function to simplify debugging - trace just one symbol - `usearch_raise_runtime_error`. +/// Assuming the `extern C` block, the name won't be mangled. +inline static void usearch_raise_runtime_error(char const* message) { + // On Windows we compile with `/EHc` flag, which specifies that functions + // with C linkage do not throw C++ exceptions. +#if !defined(__cpp_exceptions) || defined(USEARCH_DEFINED_WINDOWS) + std::terminate(); +#else + throw std::runtime_error(message); +#endif +} +} + +namespace unum { +namespace usearch { + +using byte_t = char; + +template std::size_t divide_round_up(std::size_t num) noexcept { + return (num + multiple_ak - 1) / multiple_ak; +} + +inline std::size_t divide_round_up(std::size_t num, std::size_t denominator) noexcept { + return (num + denominator - 1) / denominator; +} + +inline std::size_t ceil2(std::size_t v) noexcept { + v--; + v |= v >> 1; + v |= v >> 2; + v |= v >> 4; + v |= v >> 8; + v |= v >> 16; +#ifdef USEARCH_64BIT_ENV + v |= v >> 32; +#endif + v++; + return v; +} + +/// @brief Simply dereferencing misaligned pointers can be dangerous. +template void misaligned_store(void* ptr, at v) noexcept { + static_assert(!std::is_reference::value, "Can't store a reference"); + std::memcpy(ptr, &v, sizeof(at)); +} + +/// @brief Simply dereferencing misaligned pointers can be dangerous. +template at misaligned_load(void const* ptr) noexcept { + static_assert(!std::is_reference::value, "Can't load a reference"); + at v; + std::memcpy(&v, ptr, sizeof(at)); + return v; +} + +/// @brief The `std::exchange` alternative for C++11. +template at exchange(at& obj, other_at&& new_value) { + at old_value = std::move(obj); + obj = std::forward(new_value); + return old_value; +} + +#if defined(USEARCH_DEFINED_CPP20) + +template void destroy_at(at* obj) { std::destroy_at(obj); } +template void construct_at(at* obj) { std::construct_at(obj); } + +#else + +/// @brief The `std::destroy_at` alternative for C++11. +template +typename std::enable_if::value>::type destroy_at(at*) {} +template +typename std::enable_if::value>::type destroy_at(at* obj) { + obj->~sfinae_at(); +} + +/// @brief The `std::construct_at` alternative for C++11. +template +typename std::enable_if::value>::type construct_at(at*) {} +template +typename std::enable_if::value>::type construct_at(at* obj) { + new (obj) at(); +} + +#endif + +/** + * @brief A reference to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. + */ +template class misaligned_ref_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + misaligned_ref_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + operator mutable_t() const noexcept { return misaligned_load(ptr_); } + misaligned_ref_gt& operator=(mutable_t const& v) noexcept { + misaligned_store(ptr_, v); + return *this; + } + + void reset(byte_t* ptr) noexcept { ptr_ = ptr; } + byte_t* ptr() const noexcept { return ptr_; } +}; + +/** + * @brief A pointer to a misaligned memory location with a specific type. + * It is needed to avoid Undefined Behavior when dereferencing addresses + * indivisible by `sizeof(at)`. + */ +template class misaligned_ptr_gt { + using element_t = at; + using mutable_t = typename std::remove_const::type; + byte_t* ptr_; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + misaligned_ptr_gt(byte_t* ptr) noexcept : ptr_(ptr) {} + + reference operator*() const noexcept { return {ptr_}; } + reference operator[](std::size_t i) noexcept { return reference(ptr_ + i * sizeof(element_t)); } + value_type operator[](std::size_t i) const noexcept { + return misaligned_load(ptr_ + i * sizeof(element_t)); + } + + misaligned_ptr_gt& operator++() noexcept { + ptr_ += sizeof(element_t); + return *this; + } + misaligned_ptr_gt& operator--() noexcept { + ptr_ -= sizeof(element_t); + return *this; + } + misaligned_ptr_gt operator++(int) noexcept { + misaligned_ptr_gt tmp = *this; + ++(*this); + return tmp; + } + misaligned_ptr_gt operator--(int) noexcept { + misaligned_ptr_gt tmp = *this; + --(*this); + return tmp; + } + misaligned_ptr_gt operator+(difference_type d) const noexcept { + return misaligned_ptr_gt(ptr_ + d * sizeof(element_t)); + } + misaligned_ptr_gt operator-(difference_type d) const noexcept { + return misaligned_ptr_gt(ptr_ - d * sizeof(element_t)); + } + difference_type operator-(const misaligned_ptr_gt& other) const noexcept { + return (ptr_ - other.ptr_) / sizeof(element_t); + } + + misaligned_ptr_gt& operator+=(difference_type d) noexcept { + ptr_ += d * sizeof(element_t); + return *this; + } + misaligned_ptr_gt& operator-=(difference_type d) noexcept { + ptr_ -= d * sizeof(element_t); + return *this; + } + + bool operator==(misaligned_ptr_gt const& other) const noexcept { return ptr_ == other.ptr_; } + bool operator!=(misaligned_ptr_gt const& other) const noexcept { return ptr_ != other.ptr_; } + bool operator<(misaligned_ptr_gt const& other) const noexcept { return ptr_ < other.ptr_; } + bool operator<=(misaligned_ptr_gt const& other) const noexcept { return ptr_ <= other.ptr_; } + bool operator>(misaligned_ptr_gt const& other) const noexcept { return ptr_ > other.ptr_; } + bool operator>=(misaligned_ptr_gt const& other) const noexcept { return ptr_ >= other.ptr_; } +}; + +/** + * @brief Non-owning memory range view, similar to `std::span`, but for C++11. + */ +template class span_gt { + scalar_at* data_; + std::size_t size_; + + public: + span_gt() noexcept : data_(nullptr), size_(0u) {} + span_gt(scalar_at* begin, scalar_at* end) noexcept : data_(begin), size_(end - begin) {} + span_gt(scalar_at* begin, std::size_t size) noexcept : data_(begin), size_(size) {} + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } +}; + +/** + * @brief Similar to `std::vector`, but doesn't support dynamic resizing. + * On the bright side, this can't throw exceptions. + */ +template > class buffer_gt { + scalar_at* data_; + std::size_t size_; + + public: + buffer_gt() noexcept : data_(nullptr), size_(0u) {} + buffer_gt(std::size_t size) noexcept : data_(allocator_at{}.allocate(size)), size_(data_ ? size : 0u) { + if (!std::is_trivially_default_constructible::value) + for (std::size_t i = 0; i != size_; ++i) + construct_at(data_ + i); + } + ~buffer_gt() noexcept { reset(); } + void reset() noexcept { + if (!std::is_trivially_destructible::value) + for (std::size_t i = 0; i != size_; ++i) + unum::usearch::destroy_at(data_ + i); //< Facing some symbol visibility/ambiguity issues + allocator_at{}.deallocate(data_, size_); + data_ = nullptr; + size_ = 0; + } + scalar_at* data() const noexcept { return data_; } + std::size_t size() const noexcept { return size_; } + scalar_at* begin() const noexcept { return data_; } + scalar_at* end() const noexcept { return data_ + size_; } + operator scalar_at*() const noexcept { return data(); } + scalar_at& operator[](std::size_t i) noexcept { return data_[i]; } + scalar_at const& operator[](std::size_t i) const noexcept { return data_[i]; } + explicit operator bool() const noexcept { return data_; } + scalar_at* release() noexcept { + size_ = 0; + return exchange(data_, nullptr); + } + + buffer_gt(buffer_gt const&) = delete; + buffer_gt& operator=(buffer_gt const&) = delete; + + buffer_gt(buffer_gt&& other) noexcept : data_(exchange(other.data_, nullptr)), size_(exchange(other.size_, 0)) {} + buffer_gt& operator=(buffer_gt&& other) noexcept { + std::swap(data_, other.data_); + std::swap(size_, other.size_); + return *this; + } +}; + +/** + * @brief A lightweight error class for handling error messages, + * which are expected to be allocated in static memory. + */ +class error_t { + char const* message_{}; + + public: + error_t() noexcept : message_(nullptr) {} + error_t(char const* message) noexcept : message_(message) {} + error_t& operator=(char const* message) noexcept { + message_ = message; + return *this; + } + + error_t(error_t const&) = delete; + error_t& operator=(error_t const&) = delete; + error_t(error_t&& other) noexcept : message_(exchange(other.message_, nullptr)) {} + error_t& operator=(error_t&& other) noexcept { + std::swap(message_, other.message_); + return *this; + } + + /// @brief Checks if there was an error. + explicit operator bool() const noexcept { return message_ != nullptr; } + + /// @brief Returns the error message. + char const* what() const noexcept { return message_; } + + /// @brief Releases the error message, meaning the caller takes ownership. + char const* release() noexcept { return exchange(message_, nullptr); } + +#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) + /// @brief Destructor raises an exception if an error was recorded. + ~error_t() noexcept(false) { +#if defined(USEARCH_DEFINED_CPP17) + if (message_ && std::uncaught_exceptions() == 0) +#else + if (message_ && std::uncaught_exception() == 0) +#endif + raise(); + } + + /// @brief Throws an exception using to be caught by `try` / `catch`. + void raise() noexcept(false) { + if (message_) + throw std::runtime_error(exchange(message_, nullptr)); + } +#else + /// @brief Destructor terminates if an error was recorded. + ~error_t() noexcept { raise(); } + + /// @brief Terminates if an error was recorded. + void raise() noexcept { + if (message_) + std::terminate(); + } +#endif +}; + +/** + * @brief Similar to `std::expected` in C++23, wraps a statement evaluation result, + * or an error. It's used to avoid raising exception, and gracefully propagate + * the error. + * + * @tparam result_at The type of the expected result. + */ +template struct expected_gt { + result_at result; + error_t error; + + operator result_at&() & { + error.raise(); + return result; + } + operator result_at&&() && { + error.raise(); + return std::move(result); + } + result_at const& operator*() const noexcept { return result; } + explicit operator bool() const noexcept { return !error; } + expected_gt failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Light-weight bitset implementation to sync nodes updates during graph mutations. + * Extends basic functionality with @b atomic operations. + */ +template > class bitset_gt { + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + using compressed_slot_t = unsigned long; + + static constexpr std::size_t bits_per_slot() { return sizeof(compressed_slot_t) * CHAR_BIT; } + static constexpr compressed_slot_t bits_mask() { return sizeof(compressed_slot_t) * CHAR_BIT - 1; } + static constexpr std::size_t slots(std::size_t bits) { return divide_round_up(bits); } + + compressed_slot_t* slots_{}; + /// @brief Number of slots. + std::size_t count_{}; + + public: + bitset_gt() noexcept {} + ~bitset_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + void clear() noexcept { + if (slots_) + std::memset(slots_, 0, count_ * sizeof(compressed_slot_t)); + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, count_ * sizeof(compressed_slot_t)); + slots_ = nullptr; + count_ = 0; + } + + bitset_gt(std::size_t capacity) noexcept + : slots_((compressed_slot_t*)allocator_t{}.allocate(slots(capacity) * sizeof(compressed_slot_t))), + count_(slots_ ? slots(capacity) : 0u) { + clear(); + } + + bitset_gt(bitset_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + count_ = exchange(other.count_, 0); + } + + bitset_gt& operator=(bitset_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(count_, other.count_); + return *this; + } + + bitset_gt(bitset_gt const&) = delete; + bitset_gt& operator=(bitset_gt const&) = delete; + + inline bool test(std::size_t i) const noexcept { return slots_[i / bits_per_slot()] & (1ul << (i & bits_mask())); } + inline bool set(std::size_t i) noexcept { + compressed_slot_t& slot = slots_[i / bits_per_slot()]; + compressed_slot_t mask{1ul << (i & bits_mask())}; + bool value = slot & mask; + slot |= mask; + return value; + } + +#if defined(USEARCH_DEFINED_WINDOWS) + + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return InterlockedOr((long volatile*)&slots_[i / bits_per_slot()], mask) & mask; + } + + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + InterlockedAnd((long volatile*)&slots_[i / bits_per_slot()], ~mask); + } + +#else + + inline bool atomic_set(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + return __atomic_fetch_or(&slots_[i / bits_per_slot()], mask, __ATOMIC_ACQUIRE) & mask; + } + + inline void atomic_reset(std::size_t i) noexcept { + compressed_slot_t mask{1ul << (i & bits_mask())}; + __atomic_fetch_and(&slots_[i / bits_per_slot()], ~mask, __ATOMIC_RELEASE); + } + +#endif + + class lock_t { + bitset_gt& bitset_; + std::size_t bit_offset_; + + public: + inline ~lock_t() noexcept { bitset_.atomic_reset(bit_offset_); } + inline lock_t(bitset_gt& bitset, std::size_t bit_offset) noexcept : bitset_(bitset), bit_offset_(bit_offset) { + while (bitset_.atomic_set(bit_offset_)) + ; + } + }; + + inline lock_t lock(std::size_t i) noexcept { return {*this, i}; } +}; + +using bitset_t = bitset_gt<>; + +/** + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory, in case you want to shuffle it or sort. Good for collections + * from 100s to 10'000s elements. + * + * In a max-heap, the heap property ensures that the value of each node is greater + * than or equal to the values of its children. This means that the largest element + * is always at the root of the heap. + * + * @section Heap Structures + * + * There are several designs of heaps. Binary heaps are the simplest & most common + * variant, that is easy to implement as a succint array. However, they are not the + * most efficient for all operations. Most importantly, @b melding (merging) of + * two heaps has linear complexity in time. + * + * +-----------------+---------+-----------+---------+--------------+---------+ + * | Operation | find-max| delete-max| insert | increase-key | meld | + * +-----------------+---------+-----------+---------+--------------+---------+ + * | Binary | Θ(1) | Θ(log n) | O(log n)| O(log n) | Θ(n) | + * | Leftist | Θ(1) | Θ(log n) | O(log n)| Θ(log n) | Θ(log n)| + * | Binomial | Θ(1) | Θ(log n) | Θ(1) | Θ(log n) | O(log n)| + * | Skew binomial | Θ(1) | Θ(log n) | Θ(1) | O(log n) | O(log n)| + * | Pairing | Θ(1) | O(log n) | Θ(1) | o(log n) | Θ(1) | + * | Rank-pairing | Θ(1) | O(log n) | Θ(1) | Θ(1) | Θ(1) | + * | Fibonacci | Θ(1) | O(log n) | Θ(1) | Θ(1) | Θ(1) | + * | Strict Fibonacci| Θ(1) | O(log n) | Θ(1) | Θ(1) | Θ(1) | + * | Brodal | Θ(1) | Θ(log n) | Θ(1) | Θ(1) | Θ(1) | + * | 2–3 heap | Θ(1) | O(log n) | Θ(1) | Θ(1) | O(log n)| + * +-----------------+---------+-----------+---------+--------------+---------+ + * + * It's well known, that improved priority queue structures translate into better + * graph-transversal algorithms. For example, Dijkstra's algorithm can be sped up + * by using a Fibonacci heap for arbitrary weights. For integer weight bounded + * by L, Schrijver reported following time complexities in 2004: + * + * +------------+-------------------------------------+----------------------------+--------------------------+ + * | Weights | Algorithm | Time complexity | Author | + * +------------+-------------------------------------+----------------------------+--------------------------+ + * | R | | O(V^2 EL) | Ford 1956 | + * | R | Bellman–Ford algorithm | O(VE) | Shimbel 1955, Bellman | + * | | | | 1958, Moore 1959 | + * | R | | O(V^2 log V) | Dantzig 1960 | + * | R | Dijkstra's with list | O(V^2) | Leyzorek et al. 1957, | + * | | | | Dijkstra 1959... | + * | R | Dijkstra's with binary heap | O((E + V) log V) | Johnson 1977 | + * | R | Dijkstra's with Fibonacci heap | O(E + V log V) | Fredman & Tarjan 1984, | + * | | | | Fredman & Tarjan 1987 | + * | R | Quantum Dijkstra | O(√VE log^2 V) | Dürr et al. 2006 | + * | R | Dial's algorithm (Dijkstra's using | O(E + LV) | Dial 1969 | + * | | a bucket queue with L buckets) | | | + * | N | | O(E log log L) | Johnson 1981, Karlsson & | + * | | | | Poblete 1983 | + * | N | Gabow's algorithm | O(E log_E/V L) | Gabow 1983, Gabow 1985 | + * | N | | O(E + V √log L) | Ahuja et al. 1990 | + * | N | Thorup | O(E + V log log V) | Thorup 2004 | + * +------------+-------------------------------------+----------------------------+--------------------------+ + * + * Possible improvements: + * - Randomized meldable heaps: https://en.wikipedia.org/wiki/Randomized_meldable_heap + * - D-ary heaps: https://en.wikipedia.org/wiki/D-ary_heap + * - B-heap: https://en.wikipedia.org/wiki/B-heap + */ +template , // is needed before C++14. + typename allocator_at = std::allocator> // +class max_heap_gt { + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + using value_type = element_t; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + max_heap_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + max_heap_gt(max_heap_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + max_heap_gt& operator=(max_heap_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + max_heap_gt(max_heap_gt const&) = delete; + max_heap_gt& operator=(max_heap_gt const&) = delete; + + ~max_heap_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + inline void clear() noexcept { size_ = 0; } + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + /// @brief Selects the largest element in the heap. + /// @return Reference to the stored element. + inline element_t const& top() const noexcept { return elements_[0]; } + + /// @brief Invalidates the "max-heap" property, transforming into ascending range. + inline void sort_ascending() noexcept { std::sort_heap(elements_, elements_ + size_, &less); } + + /** + * @brief Ensures the heap has enough capacity for the specified number of elements. + * @param new_capacity The desired minimum capacity. + * @return True if the capacity was successfully increased, false otherwise. + */ + usearch_profiled_m bool reserve(std::size_t new_capacity) noexcept { + usearch_profile_name_m(max_heap_reserve); + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (elements_) { + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + allocator.deallocate(elements_, capacity_); + } + elements_ = new_elements; + capacity_ = new_capacity; + return new_elements; + } + + /** + * @brief Inserts an element into the heap. + * @param element The element to be inserted. + * @return True if the element was successfully inserted, false otherwise. + */ + bool insert(element_t&& element) noexcept { + if (!reserve(size_ + 1)) + return false; + + insert_reserved(std::move(element)); + return true; + } + + /** + * @brief Inserts an element into the heap without reserving additional space. + * @param element The element to be inserted. + */ + usearch_profiled_m void insert_reserved(element_t&& element) noexcept { + usearch_profile_name_m(max_heap_insert_reserved); + new (&elements_[size_]) element_t(element); + size_++; + shift_up(size_ - 1); + } + + /** + * @brief Inserts multiple elements into the heap. + * @param elements Pointer to the elements to be inserted. + * @return True if the elements were successfully inserted, false otherwise. + */ + inline bool insert_many(element_t const* elements) noexcept { + // Wikipedia describes a procedure, due to Floyd, which constructs a heap from an array in linear time. + // It also mentions a procedure for merging two heaps, of sizes 𝑛 and 𝑘, in time 𝑂(𝑘+log𝑘log𝑛). + // Altogether, we can add 𝑘 elements to a heap of length 𝑛 in time 𝑂(𝑘+log𝑘log𝑛): first build a heap containing + // 𝑘 elements to be inserted (takes 𝑂(𝑘) time), then merge that with the heap of size 𝑛 (takes 𝑂(𝑘+log𝑘log𝑛) + // time). Compare this to repeated insertion, which would run in time 𝑂(𝑘log𝑛). + return false; + } + + usearch_profiled_m element_t pop() noexcept { + usearch_profile_name_m(max_heap_pop); + element_t result = top(); + std::swap(elements_[0], elements_[size_ - 1]); + size_--; + elements_[size_].~element_t(); + shift_down(0); + return result; + } + + private: + static std::size_t parent_idx(std::size_t i) noexcept { return (i - 1u) / 2u; } + static std::size_t left_child_idx(std::size_t i) noexcept { return (i * 2u) + 1u; } + static std::size_t right_child_idx(std::size_t i) noexcept { return (i * 2u) + 2u; } + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } + + /** + * @brief Shifts an element up to maintain the heap property. + * This operation is called when a new element is @b added at the end of the heap. + * The element is moved up until the heap property is restored. + * @param i Index of the element to be shifted up. + */ + void shift_up(std::size_t i) noexcept { + for (; i && less(elements_[parent_idx(i)], elements_[i]); i = parent_idx(i)) + std::swap(elements_[parent_idx(i)], elements_[i]); + } + + /** + * @brief Shifts an element down to maintain the heap property. + * This operation is called when the root element is @b removed and the last element is moved to the root. + * The element is moved down until the heap property is restored. + * @param i Index of the element to be shifted down. + */ + void shift_down(std::size_t i) noexcept { + std::size_t max_idx = i; + + std::size_t left = left_child_idx(i); + if (left < size_ && less(elements_[max_idx], elements_[left])) + max_idx = left; + + std::size_t right = right_child_idx(i); + if (right < size_ && less(elements_[max_idx], elements_[right])) + max_idx = right; + + if (i != max_idx) { + std::swap(elements_[i], elements_[max_idx]); + shift_down(max_idx); + } + } +}; + +/** + * @brief Similar to `std::priority_queue`, but allows raw access to underlying + * memory and always keeps the data sorted. Ideal for small collections + * under 128 elements. + */ +template , // is needed before C++14. + typename allocator_at = std::allocator> // +class sorted_buffer_gt { + public: + using element_t = element_at; + using comparator_t = comparator_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This heap is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This heap is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_; + std::size_t size_; + std::size_t capacity_; + + public: + sorted_buffer_gt() noexcept : elements_(nullptr), size_(0), capacity_(0) {} + + sorted_buffer_gt(sorted_buffer_gt&& other) noexcept + : elements_(exchange(other.elements_, nullptr)), size_(exchange(other.size_, 0)), + capacity_(exchange(other.capacity_, 0)) {} + + sorted_buffer_gt& operator=(sorted_buffer_gt&& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(size_, other.size_); + std::swap(capacity_, other.capacity_); + return *this; + } + + sorted_buffer_gt(sorted_buffer_gt const&) = delete; + sorted_buffer_gt& operator=(sorted_buffer_gt const&) = delete; + + ~sorted_buffer_gt() noexcept { reset(); } + + void reset() noexcept { + if (elements_) + allocator_t{}.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + size_ = 0; + } + + inline bool empty() const noexcept { return !size_; } + inline std::size_t size() const noexcept { return size_; } + inline std::size_t capacity() const noexcept { return capacity_; } + inline element_t const& top() const noexcept { return elements_[size_ - 1]; } + inline void clear() noexcept { size_ = 0; } + + bool reserve(std::size_t new_capacity) noexcept { + if (new_capacity < capacity_) + return true; + + new_capacity = ceil2(new_capacity); + new_capacity = (std::max)(new_capacity, (std::max)(capacity_ * 2u, 16u)); + auto allocator = allocator_t{}; + auto new_elements = allocator.allocate(new_capacity); + if (!new_elements) + return false; + + if (size_) + std::memcpy(new_elements, elements_, size_ * sizeof(element_t)); + if (elements_) + allocator.deallocate(elements_, capacity_); + + elements_ = new_elements; + capacity_ = new_capacity; + return true; + } + + inline void insert_reserved(element_t&& element) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + std::size_t to_move = size_ - slot; + element_t* source = elements_ + size_ - 1; + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_++; + } + + /** + * @return `true` if the entry was added, `false` if it wasn't relevant enough. + */ + inline bool insert(element_t&& element, std::size_t limit) noexcept { + std::size_t slot = size_ ? std::lower_bound(elements_, elements_ + size_, element, &less) - elements_ : 0; + if (slot == limit) + return false; + std::size_t to_move = size_ - slot - (size_ == limit); + element_t* source = elements_ + size_ - 1 - (size_ == limit); + for (; to_move; --to_move, --source) + source[1] = source[0]; + elements_[slot] = element; + size_ += size_ != limit; + return true; + } + + inline element_t pop() noexcept { + size_--; + element_t result = elements_[size_]; + elements_[size_].~element_t(); + return result; + } + + void sort_ascending() noexcept {} + inline void shrink(std::size_t n) noexcept { size_ = (std::min)(n, size_); } + + inline element_t* data() noexcept { return elements_; } + inline element_t const* data() const noexcept { return elements_; } + + private: + static bool less(element_t const& a, element_t const& b) noexcept { return comparator_t{}(a, b); } +}; + +#if defined(USEARCH_DEFINED_WINDOWS) +#pragma pack(push, 1) // Pack struct elements on 1-byte alignment +#endif + +/** + * @brief Five-byte integer type to address node clouds with over 4B entries. + * + * 40 bits is enough to address a @b Trillion entries potentially colocated on 1 machine. + * At roughly 5 bytes * 20 neighbors + 100 bytes per entry, this translates to 200 TB of data, + * which is similar to a single-server capacity of modern NVME arrays. + */ +class usearch_pack_m uint40_t { + unsigned char octets[5]; + + inline uint40_t& broadcast(unsigned char c) { + std::memset(octets, c, 5); + return *this; + } + + public: + inline uint40_t() noexcept { broadcast(0); } + inline uint40_t(std::uint32_t n) noexcept { + std::memcpy(&octets, &n, 4); + octets[4] = 0; + } + +#ifdef USEARCH_64BIT_ENV + inline uint40_t(std::uint64_t n) noexcept { std::memcpy(octets, &n, 5); } +#endif + + uint40_t(uint40_t&&) = default; + uint40_t(uint40_t const&) = default; + uint40_t& operator=(uint40_t&&) = default; + uint40_t& operator=(uint40_t const&) = default; + +#if defined(USEARCH_DEFINED_CLANG) && defined(USEARCH_DEFINED_APPLE) + inline uint40_t(std::size_t n) noexcept { +#ifdef USEARCH_64BIT_ENV + std::memcpy(octets, &n, 5); +#else + std::memcpy(octets, &n, 4); + octets[4] = 0; +#endif // USEARCH_64BIT_ENV + } +#endif // USEARCH_DEFINED_CLANG && USEARCH_DEFINED_APPLE + + inline operator std::size_t() const noexcept { + std::size_t result = 0; +#ifdef USEARCH_64BIT_ENV + std::memcpy(&result, octets, 5); +#else + std::memcpy(&result, octets, 4); +#endif + return result; + } + + inline static uint40_t max() noexcept { return uint40_t{}.broadcast(0xFF); } + inline static uint40_t min() noexcept { return uint40_t{}.broadcast(0); } + + inline bool operator==(uint40_t const& other) const noexcept { return std::memcmp(octets, other.octets, 5) == 0; } + inline bool operator!=(uint40_t const& other) const noexcept { return !(*this == other); } + inline bool operator>(uint40_t const& other) const noexcept { return other < *this; } + inline bool operator<=(uint40_t const& other) const noexcept { return !(*this > other); } + inline bool operator>=(uint40_t const& other) const noexcept { return !(*this < other); } + inline bool operator<(uint40_t const& other) const noexcept { + for (int i = 0; i < 5; ++i) { + if (octets[4 - i] < other.octets[4 - i]) + return true; + if (octets[4 - i] > other.octets[4 - i]) + return false; + } + return false; + } +}; + +#if defined(USEARCH_DEFINED_WINDOWS) +#pragma pack(pop) // Reset alignment to default +#endif + +static_assert(sizeof(uint40_t) == 5, "uint40_t must be exactly 5 bytes"); + +/** + * @brief Reflection-helper to get the default "unused" value for a given type. + * Needed to initialize hash-sets and bit-sets. + */ +template struct default_free_value_gt { + template ::value>::type* = nullptr> + static sfinae_element_at value() noexcept { + return std::numeric_limits::max(); + } + template ::value>::type* = nullptr> + static sfinae_element_at value() noexcept { + return element_at(); + } +}; + +template <> struct default_free_value_gt { + static uint40_t value() noexcept { return uint40_t::max(); } +}; + +template element_at default_free_value() { return default_free_value_gt::value(); } + +/** + * @brief Adapter to allow definining arbitrary hash functions for keys and slots. + * It's added, as overloading `std::hash` is not recommended by the standard. + */ +template struct hash_gt { + std::size_t operator()(element_at const& element) const noexcept { return std::hash{}(element); } +}; + +template <> struct hash_gt { + std::size_t operator()(uint40_t const& element) const noexcept { return std::hash{}(element); } +}; + +/** + * @brief Minimalistic hash-set implementation to track visited nodes during graph traversal. + * In our primary usecase, its a sparse alternative to a bit-set. + * + * It doesn't support deletion of separate objects, but supports `clear`-ing all at once. + * It expects `reserve` to be called ahead of all insertions, so no resizes are needed. + * It also assumes `0xFF...FF` slots to be unused, to simplify the design. + * It uses linear probing, the number of slots is always a power of two, and it uses linear-probing + * in case of bucket collisions. + */ +template , typename allocator_at = std::allocator> +class growing_hash_set_gt { + + using element_t = element_at; + using hasher_t = hasher_at; + + using allocator_t = allocator_at; + using byte_t = typename allocator_t::value_type; + static_assert(sizeof(byte_t) == 1, "Allocator must allocate separate addressable bytes"); + + element_t* slots_{}; + /// @brief Number of slots. + std::size_t capacity_{}; + /// @brief Number of populated. + std::size_t count_{}; + hasher_t hasher_{}; + + public: + growing_hash_set_gt() noexcept {} + ~growing_hash_set_gt() noexcept { reset(); } + + explicit operator bool() const noexcept { return slots_; } + std::size_t size() const noexcept { return count_; } + + void clear() noexcept { + if (slots_) + std::memset((void*)slots_, 0xFF, capacity_ * sizeof(element_t)); + count_ = 0; + } + + void reset() noexcept { + if (slots_) + allocator_t{}.deallocate((byte_t*)slots_, capacity_ * sizeof(element_t)); + slots_ = nullptr; + capacity_ = 0; + count_ = 0; + } + + growing_hash_set_gt(std::size_t capacity) noexcept + : slots_((element_t*)allocator_t{}.allocate(ceil2(capacity) * sizeof(element_t))), + capacity_(slots_ ? ceil2(capacity) : 0u), count_(0u) { + clear(); + } + + growing_hash_set_gt(growing_hash_set_gt&& other) noexcept { + slots_ = exchange(other.slots_, nullptr); + capacity_ = exchange(other.capacity_, 0); + count_ = exchange(other.count_, 0); + } + + growing_hash_set_gt& operator=(growing_hash_set_gt&& other) noexcept { + std::swap(slots_, other.slots_); + std::swap(capacity_, other.capacity_); + std::swap(count_, other.count_); + return *this; + } + + growing_hash_set_gt(growing_hash_set_gt const&) = delete; + growing_hash_set_gt& operator=(growing_hash_set_gt const&) = delete; + + /** + * @brief Checks if the element is already in the hash-set. + * @return `true` if the element is already in the hash-set. + */ + inline bool test(element_t const& elem) const noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + return false; + } + + /** + * @brief Inserts an element into the hash-set. + * @return Similar to `bitset_gt`, returns the previous value. + */ + inline bool set(element_t const& elem) noexcept { + std::size_t index = hasher_(elem) & (capacity_ - 1); + while (slots_[index] != default_free_value()) { + // Already exists + if (slots_[index] == elem) + return true; + + index = (index + 1) & (capacity_ - 1); + } + slots_[index] = elem; + ++count_; + return false; + } + + /** + * @brief Extends the capacity of the hash-set. + * @return `true` if enough capacity is available, `false` if memory allocation failed. + */ + bool reserve(std::size_t new_capacity) noexcept { + new_capacity = (new_capacity * 5u) / 3u; + if (new_capacity <= capacity_) + return true; + + new_capacity = ceil2(new_capacity); + element_t* new_slots = (element_t*)allocator_t{}.allocate(new_capacity * sizeof(element_t)); + if (!new_slots) + return false; + + std::memset((void*)new_slots, 0xFF, new_capacity * sizeof(element_t)); + std::size_t new_count = count_; + if (count_) { + for (std::size_t old_index = 0; old_index != capacity_; ++old_index) { + if (slots_[old_index] == default_free_value()) + continue; + + std::size_t new_index = hasher_(slots_[old_index]) & (new_capacity - 1); + while (new_slots[new_index] != default_free_value()) + new_index = (new_index + 1) & (new_capacity - 1); + new_slots[new_index] = slots_[old_index]; + } + } + + reset(); + slots_ = new_slots; + capacity_ = new_capacity; + count_ = new_count; + return true; + } +}; + +/** + * @brief Basic single-threaded @b ring class, used for all kinds of task queues. + */ +template > // +class ring_gt { + public: + using element_t = element_at; + using allocator_t = allocator_at; + + static_assert(std::is_trivially_destructible(), "This ring is designed for trivial structs"); + static_assert(std::is_trivially_copy_constructible(), "This ring is designed for trivial structs"); + + using value_type = element_t; + + private: + element_t* elements_{}; + std::size_t capacity_{}; + std::size_t head_{}; + std::size_t tail_{}; + bool empty_{true}; + allocator_t allocator_{}; + + public: + explicit ring_gt(allocator_t const& alloc = allocator_t()) noexcept : allocator_(alloc) {} + + ring_gt(ring_gt const&) = delete; + ring_gt& operator=(ring_gt const&) = delete; + + ring_gt(ring_gt&& other) noexcept { swap(other); } + ring_gt& operator=(ring_gt&& other) noexcept { + swap(other); + return *this; + } + + void swap(ring_gt& other) noexcept { + std::swap(elements_, other.elements_); + std::swap(capacity_, other.capacity_); + std::swap(head_, other.head_); + std::swap(tail_, other.tail_); + std::swap(empty_, other.empty_); + std::swap(allocator_, other.allocator_); + } + + ~ring_gt() noexcept { reset(); } + + bool empty() const noexcept { return empty_; } + size_t capacity() const noexcept { return capacity_; } + size_t size() const noexcept { + if (empty_) + return 0; + else if (head_ > tail_) + return head_ - tail_; + else + return capacity_ - (tail_ - head_); + } + + void clear() noexcept { + head_ = 0; + tail_ = 0; + empty_ = true; + } + + void reset() noexcept { + if (elements_) + allocator_.deallocate(elements_, capacity_); + elements_ = nullptr; + capacity_ = 0; + head_ = 0; + tail_ = 0; + empty_ = true; + } + + bool reserve(std::size_t n) noexcept { + if (n < size()) + return false; // prevent data loss + if (n <= capacity()) + return true; + n = (std::max)(ceil2(n), 64u); + element_t* elements = allocator_.allocate(n); + if (!elements) + return false; + + std::size_t i = 0; + while (try_pop(elements[i])) + i++; + + reset(); + elements_ = elements; + capacity_ = n; + head_ = i; + tail_ = 0; + empty_ = (i == 0); + return true; + } + + void push(element_t const& value) usearch_noexcept_m { + usearch_assert_m(capacity() > 0, "Ring buffer is not initialized"); + usearch_assert_m(size() < capacity(), "Ring buffer is full"); + elements_[head_] = value; + head_ = (head_ + 1) % capacity_; + empty_ = false; + } + + bool try_push(element_t const& value) noexcept { + if (head_ == tail_ && !empty_) + return false; // `elements_` is full + + return push(value); + return true; + } + + bool try_pop(element_t& value) noexcept { + if (empty_) + return false; + + value = std::move(elements_[tail_]); + tail_ = (tail_ + 1) % capacity_; + empty_ = head_ == tail_; + return true; + } + + element_t const& operator[](std::size_t i) const noexcept { return elements_[(tail_ + i) % capacity_]; } +}; + +/// @brief Number of neighbors per graph node. +/// Defaults to 32 in FAISS and 16 in hnswlib. +/// > It is called `M` in the paper. +constexpr std::size_t default_connectivity() { return 16; } + +/// @brief Hyper-parameter controlling the quality of indexing. +/// Defaults to 40 in FAISS and 200 in hnswlib. +/// > It is called `efConstruction` in the paper. +constexpr std::size_t default_expansion_add() { return 128; } + +/// @brief Hyper-parameter controlling the quality of search. +/// Defaults to 16 in FAISS and 10 in hnswlib. +/// > It is called `ef` in the paper. +constexpr std::size_t default_expansion_search() { return 64; } + +constexpr std::size_t default_allocator_entry_bytes() { return 64; } + +/** + * @brief Configuration settings for the index construction. + * Includes the main `::connectivity` parameter (`M` in the paper) + * and two expansion factors - for construction and search. + */ +struct index_config_t { + /// @brief Number of neighbors per graph node. + /// Defaults to 32 in FAISS and 16 in hnswlib. + /// > It is called `M` in the paper. + std::size_t connectivity = default_connectivity(); + + /// @brief Number of neighbors per graph node in base level graph. + /// Defaults to double of the other levels, so 64 in FAISS and 32 in hnswlib. + /// > It is called `M0` in the paper. + std::size_t connectivity_base = default_connectivity() * 2; + + inline index_config_t() = default; + inline index_config_t(std::size_t c, std::size_t cb = 0) noexcept : connectivity(c), connectivity_base(cb) {} + + /** + * @brief Validates the configuration settings, updating them in-place. + * @return Error message, if any. + */ + inline error_t validate() noexcept { + if (connectivity == 0) + connectivity = default_connectivity(); + if (connectivity_base == 0) + connectivity_base = connectivity * 2; + if (connectivity < 2) + return "Connectivity must be at least 2, otherwise the index degenerates into ropes"; + if (connectivity_base < connectivity) + return "Base layer should be at least as connected as the rest of the graph"; + return {}; + } + + /** + * @brief Immutable function to check if the configuration is valid. + * @return `true` if the configuration is valid. + */ + inline bool is_valid() const noexcept { return connectivity >= 2 && connectivity_base >= connectivity; } +}; + +/** + * @brief Growth settings for the index container. + * Includes the upper bound for `::members` capacity, + * and the number of read/write threads expected to work with the index. + */ +struct index_limits_t { + /// @brief Maximum number of entries in the index. + std::size_t members = 0; + /// @brief Max number of threads simultaneously updating entries. + std::size_t threads_add = std::thread::hardware_concurrency(); + /// @brief Max number of threads simultaneously searching entries. + std::size_t threads_search = std::thread::hardware_concurrency(); + + inline index_limits_t(std::size_t n, std::size_t t) noexcept : members(n), threads_add(t), threads_search(t) {} + inline index_limits_t(std::size_t n = 0) noexcept : index_limits_t(n, std::thread::hardware_concurrency()) {} + /// @brief Returns the upper limit for the number of threads. + inline std::size_t threads() const noexcept { return (std::max)(threads_add, threads_search); } + /// @brief Returns the concurrency-level of the index - the minimum of thread counts. + inline std::size_t concurrency() const noexcept { return (std::min)(threads_add, threads_search); } +}; + +struct index_update_config_t { + /// @brief Hyper-parameter controlling the quality of indexing. + /// Defaults to 40 in FAISS and 200 in hnswlib. + /// > It is called `efConstruction` in the paper. + std::size_t expansion = default_expansion_add(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; +}; + +struct index_search_config_t { + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; + + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; +}; + +struct index_cluster_config_t { + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Optional thread identifier for multi-threaded construction. + std::size_t thread = 0; +}; + +struct index_copy_config_t {}; + +struct index_join_config_t { + /// @brief Controls maximum number of proposals per man during stable marriage. + std::size_t max_proposals = 0; + + /// @brief Hyper-parameter controlling the quality of search. + /// Defaults to 16 in FAISS and 10 in hnswlib. + /// > It is called `ef` in the paper. + std::size_t expansion = default_expansion_search(); + + /// @brief Brute-forces exhaustive search over all entries in the index. + bool exact = false; +}; + +/// @brief C++17 and newer version deprecate the `std::result_of` +template +using return_type_gt = +#if defined(USEARCH_DEFINED_CPP17) + typename std::invoke_result::type; +#else + typename std::result_of::type; +#endif + +/** + * @brief An example of what a USearch-compatible ad-hoc filter would look like. + * + * A similar function object can be passed to search queries to further filter entries + * on their auxiliary properties, such as some categorical keys stored in an external DBMS. + */ +struct dummy_predicate_t { + template constexpr bool operator()(member_at&&) const noexcept { return true; } +}; + +/** + * @brief An example of what a USearch-compatible ad-hoc operation on in-flight entries. + * + * This kind of callbacks is used when the engine is being updated and you want to patch + * the entries, while their are still under locks - limiting concurrent access and providing + * consistency. + */ +struct dummy_callback_t { + template void operator()(member_at&&) const noexcept {} +}; + +/** + * @brief An example of what a USearch-compatible progress-bar should look like. + * + * This is particularly helpful when handling long-running tasks, like serialization, + * saving, and loading from disk, or index-level joins. + * The reporter checks return value to continue or stop the process, `false` means need to stop. + */ +struct dummy_progress_t { + inline bool operator()(std::size_t /*processed*/, std::size_t /*total*/) const noexcept { return true; } +}; + +/** + * @brief An example of what a USearch-compatible values prefetching mechanism should look like. + * + * USearch is designed to handle very large datasets, that may not fir into RAM. Fetching from + * external memory is very expensive, so we've added a pre-fetching mechanism, that accepts + * multiple objects at once, to cache in RAM ahead of the computation. + * The received iterators support both `get_slot` and `get_key` operations. + * An example usage may look like this: + * + * template + * inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept { + * for (; begin != end; ++begin) + * io_uring_prefetch(offset_in_file(get_key(begin))); + * } + */ +struct dummy_prefetch_t { + template + inline void operator()(member_citerator_like_at, member_citerator_like_at) const noexcept {} +}; + +/** + * @brief An example of what a USearch-compatible executor (thread-pool) should look like. + * + * It's expected to have `parallel(callback)` API to schedule one task per thread; + * an identical `fixed(count, callback)` and `dynamic(count, callback)` overloads that also accepts + * the number of tasks, and somehow schedules them between threads; as well as `size()` to + * determine the number of available threads. + */ +struct dummy_executor_t { + dummy_executor_t() noexcept {} + std::size_t size() const noexcept { return 1; } + + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + thread_aware_function(0, task_idx); + } + + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept { + for (std::size_t task_idx = 0; task_idx != tasks; ++task_idx) + if (!thread_aware_function(0, task_idx)) + break; + } + + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept { + thread_aware_function(0); + } +}; + +/** + * @brief An example of what a USearch-compatible key-to-key mapping should look like. + * + * This is particularly helpful for "Semantic Joins", where we map entries of one collection + * to entries of another. In asymmetric setups, where A -> B is needed, but B -> A is not, + * this can be passed to minimize memory usage. + */ +struct dummy_key_to_key_mapping_t { + struct member_ref_t { + template member_ref_t& operator=(key_at&&) noexcept { return *this; } + }; + template member_ref_t operator[](key_at&&) const noexcept { return {}; } +}; + +/** + * @brief Checks if the provided object has a dummy type, emulating an interface, + * but performing no real computation. + */ +template static constexpr bool is_dummy() { + using object_t = typename std::remove_all_extents::type; + return std::is_same::type, dummy_predicate_t>::value || // + std::is_same::type, dummy_callback_t>::value || // + std::is_same::type, dummy_progress_t>::value || // + std::is_same::type, dummy_prefetch_t>::value || // + std::is_same::type, dummy_executor_t>::value || // + std::is_same::type, dummy_key_to_key_mapping_t>::value; +} + +template struct has_reset_gt { + static_assert(std::integral_constant::value, "Second template parameter needs to be of function type."); +}; + +template +struct has_reset_gt { + private: + template + static constexpr auto check(at*) -> + typename std::is_same().reset(std::declval()...)), return_at>::type; + template static constexpr std::false_type check(...); + + typedef decltype(check(0)) type; + + public: + static constexpr bool value = type::value; +}; + +/** + * @brief Checks if a certain class has a member function called `reset`. + */ +template constexpr bool has_reset() { return has_reset_gt::value; } + +struct serialization_result_t { + error_t error; + + explicit operator bool() const noexcept { return !error; } + serialization_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary file @b outputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. + */ +class output_file_t { + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + output_file_t(char const* path) noexcept : path_(path) {} + ~output_file_t() noexcept { close(); } + output_file_t(output_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + output_file_t& operator=(output_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "wb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t write(void const* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t written = std::fwrite(begin, length, 1, file_); + if (length && !written) + return result.failed(std::strerror(errno)); + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } +}; + +/** + * @brief Smart-pointer wrapping the LibC @b `FILE` for binary files @b inputs. + * + * This class raises no exceptions and corresponds errors through `serialization_result_t`. + * The class automatically closes the file when the object is destroyed. + */ +class input_file_t { + char const* path_ = nullptr; + std::FILE* file_ = nullptr; + + public: + input_file_t(char const* path) noexcept : path_(path) {} + ~input_file_t() noexcept { close(); } + input_file_t(input_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), file_(exchange(other.file_, nullptr)) {} + input_file_t& operator=(input_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(file_, other.file_); + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!file_) + file_ = std::fopen(path_, "rb"); + if (!file_) + return result.failed(std::strerror(errno)); + return result; + } + serialization_result_t read(void* begin, std::size_t length) noexcept { + serialization_result_t result; + std::size_t read = std::fread(begin, length, 1, file_); + if (length && !read) { + bool reached_eof = std::feof(file_); + return result.failed(reached_eof ? "End of file reached!" : std::strerror(errno)); + } + return result; + } + void close() noexcept { + if (file_) + std::fclose(exchange(file_, nullptr)); + } + + explicit operator bool() const noexcept { return file_; } + bool seek_to(std::size_t progress) noexcept { + return std::fseek(file_, static_cast(progress), SEEK_SET) == 0; + } + bool seek_to_end() noexcept { return std::fseek(file_, 0L, SEEK_END) == 0; } + bool infer_progress(std::size_t& progress) noexcept { + long int result = std::ftell(file_); + if (result == -1L) + return false; + progress = static_cast(result); + return true; + } +}; + +/** + * @brief Represents a memory-mapped file or a pre-allocated anonymous memory region. + * + * This class provides a convenient way to memory-map a file and access its contents as a block of + * memory. The class handles platform-specific memory-mapping operations on Windows, Linux, and MacOS. + * The class automatically closes the file when the object is destroyed. + */ +class memory_mapped_file_t { + char const* path_{}; /**< The path to the file to be memory-mapped. */ + void* ptr_{}; /**< A pointer to the memory-mapping. */ + size_t length_{}; /**< The length of the memory-mapped file in bytes. */ + +#if defined(USEARCH_DEFINED_WINDOWS) + HANDLE file_handle_{}; /**< The file handle on Windows. */ + HANDLE mapping_handle_{}; /**< The mapping handle on Windows. */ +#else + int file_descriptor_{}; /**< The file descriptor on Linux and MacOS. */ +#endif + + public: + explicit operator bool() const noexcept { return ptr_ != nullptr; } + byte_t* data() noexcept { return reinterpret_cast(ptr_); } + byte_t const* data() const noexcept { return reinterpret_cast(ptr_); } + std::size_t size() const noexcept { return static_cast(length_); } + + memory_mapped_file_t() noexcept {} + memory_mapped_file_t(char const* path) noexcept : path_(path) {} + ~memory_mapped_file_t() noexcept { close(); } + memory_mapped_file_t(memory_mapped_file_t&& other) noexcept + : path_(exchange(other.path_, nullptr)), ptr_(exchange(other.ptr_, nullptr)), + length_(exchange(other.length_, 0)), +#if defined(USEARCH_DEFINED_WINDOWS) + file_handle_(exchange(other.file_handle_, nullptr)), mapping_handle_(exchange(other.mapping_handle_, nullptr)) +#else + file_descriptor_(exchange(other.file_descriptor_, 0)) +#endif + { + } + + memory_mapped_file_t(memory_mapped_file_t const&) = delete; + memory_mapped_file_t& operator=(memory_mapped_file_t const&) = delete; + + memory_mapped_file_t(byte_t* data, std::size_t length) noexcept : ptr_(data), length_(length) {} + + memory_mapped_file_t& operator=(memory_mapped_file_t&& other) noexcept { + std::swap(path_, other.path_); + std::swap(ptr_, other.ptr_); + std::swap(length_, other.length_); +#if defined(USEARCH_DEFINED_WINDOWS) + std::swap(file_handle_, other.file_handle_); + std::swap(mapping_handle_, other.mapping_handle_); +#else + std::swap(file_descriptor_, other.file_descriptor_); +#endif + return *this; + } + + serialization_result_t open_if_not() noexcept { + serialization_result_t result; + if (!path_ || ptr_) + return result; + +#if defined(USEARCH_DEFINED_WINDOWS) + + HANDLE file_handle = + CreateFile(path_, GENERIC_READ, FILE_SHARE_READ, 0, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, 0); + if (file_handle == INVALID_HANDLE_VALUE) + return result.failed("Opening file failed!"); + + std::size_t file_length = GetFileSize(file_handle, 0); + HANDLE mapping_handle = CreateFileMapping(file_handle, 0, PAGE_READONLY, 0, 0, 0); + if (mapping_handle == 0) { + CloseHandle(file_handle); + return result.failed("Mapping file failed!"); + } + + byte_t* file = (byte_t*)MapViewOfFile(mapping_handle, FILE_MAP_READ, 0, 0, file_length); + if (file == 0) { + CloseHandle(mapping_handle); + CloseHandle(file_handle); + return result.failed("View the map failed!"); + } + file_handle_ = file_handle; + mapping_handle_ = mapping_handle; + ptr_ = file; + length_ = file_length; +#else + +#if defined(USEARCH_DEFINED_LINUX) + int descriptor = open(path_, O_RDONLY | O_NOATIME); +#else + int descriptor = open(path_, O_RDONLY); +#endif + if (descriptor < 0) + return result.failed(std::strerror(errno)); + + // Estimate the file size + struct stat file_stat; + int fstat_status = fstat(descriptor, &file_stat); + if (fstat_status < 0) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + + // Map the entire file + byte_t* file = (byte_t*)mmap(NULL, file_stat.st_size, PROT_READ, MAP_SHARED, descriptor, 0); + if (file == MAP_FAILED) { + ::close(descriptor); + return result.failed(std::strerror(errno)); + } + file_descriptor_ = descriptor; + ptr_ = file; + length_ = file_stat.st_size; +#endif // Platform specific code + return result; + } + + void close() noexcept { + if (!path_) { + ptr_ = nullptr; + length_ = 0; + return; + } +#if defined(USEARCH_DEFINED_WINDOWS) + UnmapViewOfFile(ptr_); + CloseHandle(mapping_handle_); + CloseHandle(file_handle_); + mapping_handle_ = nullptr; + file_handle_ = nullptr; +#else + munmap(ptr_, length_); + ::close(file_descriptor_); + file_descriptor_ = 0; +#endif + ptr_ = nullptr; + length_ = 0; + } +}; + +/** + * @brief Metadata header for the serialized index. + * + * This structure is very minimalistic by design. It contains no information + * about the capacity of the index, so you'll have to `reserve` after loading. + * It also contains no info on the metric or key types, so you'll have to store + * that information elsewhere, like we do in `index_dense_head_t`. + */ +struct index_serialized_header_t { + std::uint64_t size = 0; + std::uint64_t connectivity = 0; + std::uint64_t connectivity_base = 0; + std::uint64_t max_level = 0; + std::uint64_t entry_slot = 0; +}; + +using default_key_t = std::uint64_t; +using default_slot_t = std::uint32_t; +using default_distance_t = float; + +template struct member_gt { + key_at key; + std::size_t slot; +}; + +template inline std::size_t get_slot(member_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_gt const& m) noexcept { return m.key; } + +template struct member_cref_gt { + misaligned_ref_gt key; + std::size_t slot; +}; + +template inline std::size_t get_slot(member_cref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_cref_gt const& m) noexcept { return m.key; } + +template struct member_ref_gt { + misaligned_ref_gt key; + std::size_t slot; + + inline operator member_cref_gt() const noexcept { return {key.ptr(), slot}; } +}; + +template inline std::size_t get_slot(member_ref_gt const& m) noexcept { return m.slot; } +template inline key_at get_key(member_ref_gt const& m) noexcept { return m.key; } + +/** + * @brief Approximate Nearest Neighbors Search @b index-structure using the + * Hierarchical Navigable Small World @b (HNSW) graphs algorithm. + * If classical containers store @b Key->Value mappings, this one can + * be seen as a network of keys, accelerating approximate @b Value~>Key visited_members. + * + * Unlike most implementations, this one is generic and can be used for any search, + * not just within equi-dimensional vectors. Examples range from Texts to similar Chess + * positions, Geo-Spatial Search, and even Graphs. + * + * @tparam key_at + * The type of primary objects stored in the index. + * The values, to which those map, are not managed by the same index structure. + * + * @tparam compressed_slot_at + * The smallest unsigned integer type to address indexed elements. + * It is used internally to maximize space-efficiency and is generally + * up-casted to @b `std::size_t` in public interfaces. + * Can be a built-in @b `uint32_t`, `uint64_t`, or our custom @b `uint40_t`. + * Which makes the most sense for 4B+ entry indexes. + * + * @tparam dynamic_allocator_at + * Dynamic memory allocator for temporary buffers, visits indicators, and + * priority queues, needed during construction and traversals of graphs. + * The allocated buffers may be uninitialized. + * + * @tparam tape_allocator_at + * Potentially different memory allocator for primary allocations of nodes and vectors. + * It would never `deallocate` separate entries, and would only free all the space at once. + * The allocated buffers may be uninitialized. + * + * @section Features + * + * - Thread-safe for concurrent construction, search, and updates. + * - Doesn't allocate new threads, and reuses the ones its called from. + * - Allows storing value externally, managing just the similarity index. + * - Joins. + + * @section Usage + * + * @subsection Exceptions + * + * None of the methods throw exceptions in the "Release" compilation mode. + * It may only `throw` if your memory ::dynamic_allocator_at or ::metric_at isn't + * safe to copy. + * + * @subsection Serialization + * + * When serialized, doesn't include any additional metadata. + * It is just the multi-level proximity-graph. You may want to store metadata about + * the used metric and key types somewhere else. + * + * @section Implementation Details + * + * Like every HNSW implementation, USearch builds levels of "Proximity Graphs". + * Every added vector forms a node in one or more levels of the graph. + * Every node is present in the base level. Every following level contains a smaller + * fraction of nodes. During search, the operation starts with the smaller levels + * and zooms-in on every following iteration of larger graph traversals. + * + * Just one memory allocation is performed regardless of the number of levels. + * The adjacency lists across all levels are concatenated into that single buffer. + * That buffer starts with a "head", that stores the metadata, such as the + * tallest "level" of the graph that it belongs to, the external "key", and the + * number of "dimensions" in the vector. + * + * @section Metrics, Predicates and Callbacks + * + * + * @section Smart References and Iterators + * + * - `member_citerator_t` and `member_iterator_t` have only slots, no indirections. + * + * - `member_cref_t` and `member_ref_t` contains the `slot` and a reference + * to the key. So it passes through 1 level of visited_members in `nodes_`. + * Retrieving the key via `get_key` will cause fetching yet another cache line. + * + * - `member_gt` contains an already prefetched copy of the key. + * + */ +template , // + typename tape_allocator_at = dynamic_allocator_at> // +class index_gt { + public: + using distance_t = distance_at; + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using dynamic_allocator_t = dynamic_allocator_at; + using tape_allocator_t = tape_allocator_at; + static_assert(sizeof(vector_key_t) >= sizeof(compressed_slot_t), "Having tiny keys doesn't make sense."); + static_assert(std::is_signed::value, "Distance must be a signed type, as we use the unary minus."); + + using member_cref_t = member_cref_gt; + using member_ref_t = member_ref_gt; + + template class member_iterator_gt { + using ref_t = ref_at; + using index_t = index_at; + + friend class index_gt; + member_iterator_gt() noexcept {} + member_iterator_gt(index_t* index, compressed_slot_t slot) noexcept : index_(index), slot_(slot) {} + + template ref_t call_key(std::true_type) const noexcept { + return ref_t{index_->node_at_(slot_).ckey(), slot_}; + } + template ref_t call_key(std::false_type) const noexcept { + return ref_t{index_->node_at_(slot_).key(), slot_}; + } + + index_t* index_{}; + compressed_slot_t slot_{}; + + public: + using iterator_category = std::random_access_iterator_tag; + using value_type = ref_t; + using difference_type = std::ptrdiff_t; + using pointer = void; + using reference = ref_t; + + reference operator*() const noexcept { return call_key<0>(std::is_const()); } + vector_key_t key() const noexcept { return index_->node_at_(slot_).ckey(); } + + friend inline compressed_slot_t get_slot(member_iterator_gt const& it) noexcept { return it.slot_; } + friend inline vector_key_t get_key(member_iterator_gt const& it) noexcept { return it.key(); } + + // clang-format off + member_iterator_gt operator++(int) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) + 1)); } + member_iterator_gt operator--(int) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) - 1)); } + member_iterator_gt operator+(difference_type d) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) + d)); } + member_iterator_gt operator-(difference_type d) noexcept { return member_iterator_gt(index_, static_cast(static_cast(slot_) - d)); } + member_iterator_gt& operator++() noexcept { slot_ = static_cast(static_cast(slot_) + 1); return *this; } + member_iterator_gt& operator--() noexcept { slot_ = static_cast(static_cast(slot_) - 1); return *this; } + member_iterator_gt& operator+=(difference_type d) noexcept { slot_ = static_cast(static_cast(slot_) + d); return *this; } + member_iterator_gt& operator-=(difference_type d) noexcept { slot_ = static_cast(static_cast(slot_) - d); return *this; } + bool operator==(member_iterator_gt const& other) const noexcept { return index_ == other.index_ && slot_ == other.slot_; } + bool operator!=(member_iterator_gt const& other) const noexcept { return index_ != other.index_ || slot_ != other.slot_; } + // clang-format on + }; + + using member_iterator_t = member_iterator_gt; + using member_citerator_t = member_iterator_gt; + + // STL compatibility: + using value_type = vector_key_t; + using allocator_type = dynamic_allocator_t; + using size_type = std::size_t; + using difference_type = std::ptrdiff_t; + using reference = member_ref_t; + using const_reference = member_cref_t; + using pointer = void; + using const_pointer = void; + using iterator = member_iterator_t; + using const_iterator = member_citerator_t; + using reverse_iterator = std::reverse_iterator; + using reverse_const_iterator = std::reverse_iterator; + + using dynamic_allocator_traits_t = std::allocator_traits; + using byte_t = typename dynamic_allocator_t::value_type; + static_assert( // + sizeof(byte_t) == 1, // + "Primary allocator must allocate separate addressable bytes"); + + using tape_allocator_traits_t = std::allocator_traits; + static_assert( // + sizeof(typename tape_allocator_traits_t::value_type) == 1, // + "Tape allocator must allocate separate addressable bytes"); + + private: + /** + * @brief Integer for the number of node neighbors at a specific level of the + * multi-level graph. It's selected to be `std::uint32_t` to improve the + * alignment in most common cases. + */ + using neighbors_count_t = std::uint32_t; + using level_t = std::int16_t; + + /** + * @brief How many bytes of memory are needed to form the "head" of the node. + */ + static constexpr std::size_t node_head_bytes_() { return sizeof(vector_key_t) + sizeof(level_t); } + + using nodes_mutexes_t = bitset_gt; + + using visits_hash_set_t = growing_hash_set_gt, dynamic_allocator_t>; + + struct precomputed_constants_t { + double inverse_log_connectivity{}; + std::size_t neighbors_bytes{}; + std::size_t neighbors_base_bytes{}; + }; + /// @brief A space-efficient internal data-structure used in graph traversal queues. + struct candidate_t { + distance_t distance; + compressed_slot_t slot; + inline bool operator<(candidate_t other) const noexcept { return distance < other.distance; } + }; + + using candidates_view_t = span_gt; + using candidates_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using top_candidates_t = sorted_buffer_gt, candidates_allocator_t>; + using next_candidates_t = max_heap_gt, candidates_allocator_t>; + + /** + * @brief A loosely-structured handle for every node. One such node is created for every member. + * To minimize memory usage and maximize the number of entries per cache-line, it only + * stores to pointers. The internal tape starts with a `vector_key_t` @b key, then + * a `level_t` for the number of graph @b levels in which this member appears, + * then the { `neighbors_count_t`, `compressed_slot_t`, `compressed_slot_t` ... } sequences + * for @b each-level. + */ + class node_t { + byte_t* tape_{}; + + public: + explicit node_t(byte_t* tape) noexcept : tape_(tape) {} + byte_t* tape() const noexcept { return tape_; } + byte_t* neighbors_tape() const noexcept { return tape_ + node_head_bytes_(); } + explicit operator bool() const noexcept { return tape_; } + + node_t() = default; + node_t(node_t const&) = default; + node_t& operator=(node_t const&) = default; + + misaligned_ref_gt ckey() const noexcept { return {tape_}; } + misaligned_ref_gt ckey() noexcept { return {tape_}; } + misaligned_ref_gt key() const noexcept { return {tape_}; } + misaligned_ref_gt key() noexcept { return {tape_}; } + misaligned_ref_gt level() noexcept { return {tape_ + sizeof(vector_key_t)}; } + + void key(vector_key_t v) noexcept { return misaligned_store(tape_, v); } + void level(level_t v) noexcept { return misaligned_store(tape_ + sizeof(vector_key_t), v); } + }; + + static_assert(std::is_trivially_copy_constructible::value, "Nodes must be light!"); + static_assert(std::is_trivially_destructible::value, "Nodes must be light!"); + + /** + * @brief A slice of the node's tape, containing a the list of neighbors + * for a node in a single graph level. It's pre-allocated to fit + * as many neighbors "slots", as may be needed at the target level, + * and starts with a single integer `neighbors_count_t` counter. + */ + class neighbors_ref_t { + byte_t* tape_; + + static constexpr std::size_t shift(std::size_t i = 0) noexcept { + return sizeof(neighbors_count_t) + sizeof(compressed_slot_t) * i; + } + + public: + using iterator = misaligned_ptr_gt; + using const_iterator = misaligned_ptr_gt; + using value_type = compressed_slot_t; + + neighbors_ref_t(byte_t* tape) noexcept : tape_(tape) {} + misaligned_ptr_gt begin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() noexcept { return begin() + size(); } + misaligned_ptr_gt begin() const noexcept { return tape_ + shift(); } + misaligned_ptr_gt end() const noexcept { return begin() + size(); } + misaligned_ptr_gt cbegin() noexcept { return tape_ + shift(); } + misaligned_ptr_gt cend() noexcept { return cbegin() + size(); } + compressed_slot_t operator[](std::size_t i) const noexcept { + return misaligned_load(tape_ + shift(i)); + } + std::size_t size() const noexcept { return misaligned_load(tape_); } + void clear() noexcept { + neighbors_count_t n = misaligned_load(tape_); + std::memset(tape_, 0, shift(n)); + misaligned_store(tape_, 0); + } + void push_back(compressed_slot_t slot) noexcept { + neighbors_count_t n = misaligned_load(tape_); + misaligned_store(tape_ + shift(n), slot); + misaligned_store(tape_, n + 1); + } + template std::size_t erase_if(allow_slot_at&& allow_slot) noexcept { + std::size_t old_count = misaligned_load(tape_); + std::size_t removed_count = 0; + for (std::size_t i = 0; i < old_count; ++i) { + compressed_slot_t slot = misaligned_load(tape_ + shift(i)); + if (allow_slot(slot)) { + removed_count++; + } else { + misaligned_store(tape_ + shift(i - removed_count), slot); + } + } + misaligned_store(tape_, old_count - removed_count); + return removed_count; + } + }; + + /** + * @brief A package of all kinds of temporary data-structures, that the threads + * would reuse to process requests. Similar to having all of those as + * separate `thread_local` global variables. + */ + struct usearch_align_m context_t { + top_candidates_t top_candidates{}; + top_candidates_t top_for_refine{}; + next_candidates_t next_candidates{}; + visits_hash_set_t visits{}; + std::default_random_engine level_generator{}; + std::size_t iteration_cycles{}; + std::size_t computed_distances{}; + std::size_t computed_distances_in_refines{}; + std::size_t computed_distances_in_reverse_refines{}; + + /// @brief Heterogeneous distance calculation. + template // + inline distance_t measure(value_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances++; + return metric(first, second); + } + + /// @brief Homogeneous distance calculation. + template // + inline distance_t measure(entry_at const& first, entry_at const& second, metric_at&& metric) noexcept { + static_assert( // + std::is_same::value || std::is_same::value, + "Unexpected type"); + + computed_distances++; + return metric(first, second); + } + + /// @brief Heterogeneous batch distance calculation. + template // + inline void measure_batch(value_at const& first, entries_at const& second_entries, metric_at&& metric, + candidate_allowed_at&& candidate_allowed, transform_at&& transform, + callback_at&& callback) noexcept { + + using entry_t = typename std::remove_reference::type; + metric.batch(first, second_entries, candidate_allowed, transform, + [&](entry_t const& entry, distance_t distance) { + callback(entry, distance); + computed_distances++; + }); + } + }; + + /// @brief Number of "slots" available for `node_t` objects. Equals to @b `limits_.members`. + mutable std::atomic nodes_capacity_{}; + + /// @brief Number of "slots" already storing non-null nodes. + mutable std::atomic nodes_count_{}; + + index_config_t config_{}; + index_limits_t limits_{}; + + mutable dynamic_allocator_t dynamic_allocator_{}; + tape_allocator_t tape_allocator_{}; + + precomputed_constants_t pre_{}; + memory_mapped_file_t viewed_file_{}; + + /// @brief Controls access to `max_level_` and `entry_slot_`. + /// If any thread is updating those values, no other threads can `add()` or `search()`. + std::mutex global_mutex_{}; + + /// @brief The level of the top-most graph in the index. Grows as the logarithm of size, starts from zero. + level_t max_level_{}; + + /// @brief The slot in which the only node of the top-level graph is stored. + std::size_t entry_slot_{}; + + using nodes_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief C-style array of `node_t` smart-pointers. Use `compressed_slot_t` for indexing. + buffer_gt nodes_{}; + + /// @brief Mutex, that limits concurrent access to `nodes_`. + mutable nodes_mutexes_t nodes_mutexes_{}; + + using contexts_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + /// @brief Array of thread-specific buffers for temporary data. + mutable buffer_gt contexts_{}; + + public: + std::size_t connectivity() const noexcept { return config_.connectivity; } + std::size_t capacity() const noexcept { return nodes_capacity_; } + std::size_t size() const noexcept { return nodes_count_; } + std::size_t max_level() const noexcept { return nodes_count_ ? static_cast(max_level_) : 0; } + index_config_t const& config() const noexcept { return config_; } + index_limits_t const& limits() const noexcept { return limits_; } + bool is_immutable() const noexcept { return bool(viewed_file_); } + explicit operator bool() const noexcept { return config_.is_valid(); } + + /** + * @brief Default index constructor, suitable only for stateless allocators. + * @warning Consider `index_gt::make` instead, or explicitly convert to `bool` to check if the index is valid. + * @section Exceptions + * Doesn't throw, unless the ::dynamic_allocator's and ::tape_allocator's throw on move-construction. + */ + explicit index_gt( // + dynamic_allocator_t dynamic_allocator = {}, tape_allocator_t tape_allocator = {}) noexcept(false) + : nodes_capacity_(0u), nodes_count_(0u), config_(), limits_(0, 0), + dynamic_allocator_(std::move(dynamic_allocator)), tape_allocator_(std::move(tape_allocator)), + pre_(precompute_({})), max_level_(-1), entry_slot_(0u), nodes_(), nodes_mutexes_(), contexts_() {} + + /** + * @brief Default index constructor, suitable only for stateless allocators. + * @warning Consider `index_gt::make` instead, or explicitly convert to `bool` to check if the index is valid. + * @section Exceptions + * Doesn't throw, unless the ::dynamic_allocator's and ::tape_allocator's throw on move-construction. + */ + explicit index_gt(index_config_t config, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept(false) + : index_gt(dynamic_allocator, tape_allocator) { + config.validate(); + config_ = config; + pre_ = precompute_(config); + } + + /** + * @brief Clones the structure with the same hyper-parameters, but without contents. + */ + index_gt fork() noexcept { return index_gt{config_, dynamic_allocator_, tape_allocator_}; } + + ~index_gt() noexcept { reset(); } + + index_gt(index_gt&& other) noexcept { swap(other); } + + index_gt& operator=(index_gt&& other) noexcept { + swap(other); + return *this; + } + + struct state_result_t { + index_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + state_result_t failed(error_t message) noexcept { return {std::move(index), std::move(message)}; } + operator index_gt&&() && { + if (error) + usearch_raise_runtime_error(error.what()); + return std::move(index); + } + }; + using copy_result_t = state_result_t; + + /** + * @brief The recommended way to initialize the index, as unlike the constructor, + * it can fail with an error message, without raising an exception. + * + * @param[in] config The configuration specs of the index. + * @param[in] dynamic_allocator The allocator for temporary buffers and thread contexts, like priority queues. + * @param[in] tape_allocator The allocator for the primary allocations of nodes and vectors. + */ + static state_result_t make( // + index_config_t config = {}, dynamic_allocator_t dynamic_allocator = {}, + tape_allocator_t tape_allocator = {}) noexcept { + + state_result_t result; + result.error = config.validate(); + if (result.error) + return result; + + index_gt index; + index.config_ = std::move(config); + index.dynamic_allocator_ = std::move(dynamic_allocator); + index.tape_allocator_ = std::move(tape_allocator); + index.pre_ = precompute_(index.config_); + index.nodes_count_ = 0u; + index.max_level_ = -1; + index.entry_slot_ = 0u; + + result.index = std::move(index); + return result; + } + + /** + * @brief The recommended way to copy the index, as unlike the copy-constructor, + * it can fail with an error message, without raising an exception. + * + * @param[in] config The configuration specs for the copy-operation. Currently unused. + */ + copy_result_t copy(index_copy_config_t config = {}) const noexcept { + copy_result_t result; + index_gt& other = result.index; + other = index_gt(config_, dynamic_allocator_, tape_allocator_); + if (!other.reserve(limits_)) + return result.failed("Failed to reserve the contexts"); + + // Now all is left - is to allocate new `node_t` instances and populate + // the `other.nodes_` array into it. + for (std::size_t i = 0; i != nodes_count_; ++i) + other.nodes_[i] = other.node_make_copy_(node_bytes_(nodes_[i])); + + other.nodes_count_ = nodes_count_.load(); + other.max_level_ = max_level_; + other.entry_slot_ = entry_slot_; + + // This controls nothing for now :) + (void)config; + return result; + } + + member_citerator_t cbegin() const noexcept { return {this, static_cast(0u)}; } + member_citerator_t cend() const noexcept { return {this, static_cast(size())}; } + member_citerator_t begin() const noexcept { return {this, static_cast(0u)}; } + member_citerator_t end() const noexcept { return {this, static_cast(size())}; } + member_iterator_t begin() noexcept { return {this, static_cast(0u)}; } + member_iterator_t end() noexcept { return {this, static_cast(size())}; } + + member_ref_t at(compressed_slot_t slot) noexcept { return {nodes_[slot].key(), slot}; } + member_cref_t at(compressed_slot_t slot) const noexcept { return {nodes_[slot].ckey(), slot}; } + member_iterator_t iterator_at(compressed_slot_t slot) noexcept { return {this, slot}; } + member_citerator_t citerator_at(compressed_slot_t slot) const noexcept { return {this, slot}; } + + dynamic_allocator_t const& dynamic_allocator() const noexcept { return dynamic_allocator_; } + tape_allocator_t const& tape_allocator() const noexcept { return tape_allocator_; } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma region Adjusting Configuration +#endif + + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() noexcept { + if (!has_reset()) { + std::size_t n = nodes_count_; + for (std::size_t i = 0; i != n; ++i) + node_free_(i); + } else + tape_allocator_.deallocate(nullptr, 0); + nodes_count_ = 0; + max_level_ = -1; + entry_slot_ = 0u; + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() noexcept { + clear(); + + nodes_ = {}; + contexts_ = {}; + nodes_mutexes_ = {}; + limits_ = index_limits_t{0, 0}; + nodes_capacity_ = 0; + viewed_file_ = memory_mapped_file_t{}; + tape_allocator_ = {}; + } + + /** + * @brief Swaps the underlying memory buffers and thread contexts. + */ + void swap(index_gt& other) noexcept { + std::swap(config_, other.config_); + std::swap(limits_, other.limits_); + std::swap(dynamic_allocator_, other.dynamic_allocator_); + std::swap(tape_allocator_, other.tape_allocator_); + std::swap(pre_, other.pre_); + std::swap(viewed_file_, other.viewed_file_); + std::swap(max_level_, other.max_level_); + std::swap(entry_slot_, other.entry_slot_); + std::swap(nodes_, other.nodes_); + std::swap(nodes_mutexes_, other.nodes_mutexes_); + std::swap(contexts_, other.contexts_); + + // Non-atomic parts. + std::size_t capacity_copy = nodes_capacity_; + std::size_t count_copy = nodes_count_; + nodes_capacity_ = other.nodes_capacity_.load(); + nodes_count_ = other.nodes_count_.load(); + other.nodes_capacity_ = capacity_copy; + other.nodes_count_ = count_copy; + } + + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @return `true` on success, `false` on memory allocation errors. + */ + bool try_reserve(index_limits_t limits) usearch_noexcept_m { + + if (limits.threads_add <= limits_.threads_add // + && limits.threads_search <= limits_.threads_search // + && limits.members <= limits_.members) + return true; + + // In some cases, we don't want to update the number of members, + // just want to make sure that future reserves use the new thread limits. + if (!limits.members && !size()) { + limits_ = limits; + return true; + } + + nodes_mutexes_t new_mutexes(limits.members); + buffer_gt new_nodes(limits.members); + buffer_gt new_contexts(limits.threads()); + if (!new_nodes || !new_contexts || !new_mutexes) + return false; + + // Move the nodes info, and deallocate previous buffers. + if (nodes_) + std::memcpy(new_nodes.data(), nodes_.data(), sizeof(node_t) * size()); + + // Pre-reserve the capacity for `top_for_refine`, which always contains at most one more + // element than the connectivity factors. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + for (std::size_t i = 0; i != new_contexts.size(); ++i) + if (!new_contexts[i].top_for_refine.reserve(connectivity_max + 1)) + return false; + + limits_ = limits; + nodes_capacity_ = limits.members; + nodes_ = std::move(new_nodes); + contexts_ = std::move(new_contexts); + nodes_mutexes_ = std::move(new_mutexes); + return true; + } + + /** + * @brief Increases the `capacity()` of the index to allow adding more vectors. + * @warning Unlike STL, won't throw exceptions on memory allocations, so check the return value. + * @return `true` on success, `false` on memory allocation errors. + */ + bool reserve(index_limits_t limits) usearch_noexcept_m { return try_reserve(limits); } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Construction and Search +#endif + + struct add_result_t { + error_t error{}; + std::size_t new_size{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + std::size_t computed_distances_in_refines{}; + std::size_t computed_distances_in_reverse_refines{}; + compressed_slot_t slot{}; + + explicit operator bool() const noexcept { return !error; } + add_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /// @brief Describes a matched search result, augmenting `member_cref_t` + /// contents with `distance` to the query object. + struct match_t { + member_cref_t member; + distance_t distance; + + inline match_t() noexcept : member({nullptr, 0}), distance(std::numeric_limits::max()) {} + + inline match_t(member_cref_t member, distance_t distance) noexcept : member(member), distance(distance) {} + + inline match_t(match_t&& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t(match_t const& other) noexcept + : member({other.member.key.ptr(), other.member.slot}), distance(other.distance) {} + + inline match_t& operator=(match_t const& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + + inline match_t& operator=(match_t&& other) noexcept { + member.key.reset(other.member.key.ptr()); + member.slot = other.member.slot; + distance = other.distance; + return *this; + } + }; + + class search_result_t { + node_t const* nodes_{}; + top_candidates_t const* top_{}; + + friend class index_gt; + inline search_result_t(index_gt const& index, top_candidates_t const* top) noexcept + : nodes_(index.nodes_), top_(top) {} + + public: + /** @brief Number of search results found. */ + std::size_t count{}; + /** @brief Number of graph nodes traversed. */ + std::size_t visited_members{}; + /** @brief Number of times the distances were computed. */ + std::size_t computed_distances{}; + error_t error{}; + + inline search_result_t() noexcept {} + inline search_result_t(search_result_t&&) = default; + inline search_result_t& operator=(search_result_t&&) = default; + + explicit operator bool() const noexcept { return !error; } + search_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + inline operator std::size_t() const noexcept { return count; } + inline std::size_t size() const noexcept { return count; } + inline bool empty() const noexcept { return !count; } + inline match_t operator[](std::size_t i) const noexcept { return at(i); } + inline match_t front() const noexcept { return at(0); } + inline match_t back() const noexcept { return at(count - 1); } + inline bool contains(vector_key_t key) const noexcept { + for (std::size_t i = 0; i != count; ++i) + if (at(i).member.key == key) + return true; + return false; + } + inline match_t at(std::size_t i) const noexcept { + candidate_t const* top_ordered = top_->data(); + candidate_t candidate = top_ordered[i]; + node_t node = nodes_[candidate.slot]; + return {member_cref_t{node.ckey(), candidate.slot}, candidate.distance}; + } + + /** + * @brief Extracts the search results into a user-provided buffer, that unlike `dump_to`, + * may already contain some data, so the new and old results are merged together. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] distances The buffer to store the distances to the search results. + * @param[in] old_count The number of results already stored in the buffers. + * @param[in] max_count The maximum number of results that can be stored in the buffers. + */ + inline std::size_t merge_into( // + vector_key_t* keys, distance_t* distances, // + std::size_t old_count, std::size_t max_count) const noexcept { + + std::size_t merged_count = old_count; + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + distance_t* merged_end = distances + merged_count; + std::size_t offset = std::lower_bound(distances, merged_end, result.distance) - distances; + if (offset == max_count) + continue; + + std::size_t count_worse = merged_count - offset - (max_count == merged_count); + std::memmove(keys + offset + 1, keys + offset, count_worse * sizeof(vector_key_t)); + std::memmove(distances + offset + 1, distances + offset, count_worse * sizeof(distance_t)); + keys[offset] = result.member.key; + distances[offset] = result.distance; + merged_count += merged_count != max_count; + } + return merged_count; + } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] distances The buffer to store the distances to the search results. + */ + inline std::size_t dump_to(vector_key_t* keys, distance_t* distances) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + distances[i] = result.distance; + } + return count; + } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + */ + inline std::size_t dump_to(vector_key_t* keys) const noexcept { + for (std::size_t i = 0; i != count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + } + return count; + } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] distances The buffer to store the distances to the search results. + * @param[in] capacity The maximum number of results that can be stored in the buffers. + */ + inline std::size_t dump_to(vector_key_t* keys, distance_t* distances, std::size_t capacity) const noexcept { + std::size_t i = 0; + std::size_t initialized_count = (std::min)(count, capacity); + for (; i != initialized_count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + distances[i] = result.distance; + } + for (; i != capacity; ++i) { + keys[i] = vector_key_t{}; + distances[i] = std::numeric_limits::has_signaling_NaN + ? std::numeric_limits::signaling_NaN() + : std::numeric_limits::max(); + } + return initialized_count; + } + + /** + * @brief Extracts the search results into a user-provided buffer. + * @return The number of results stored in the buffer. + * @param[in] keys The buffer to store the keys of the search results. + * @param[in] capacity The maximum number of results that can be stored in the buffers. + */ + inline std::size_t dump_to(vector_key_t* keys, std::size_t capacity) const noexcept { + std::size_t i = 0; + std::size_t initialized_count = (std::min)(this->count, capacity); + for (; i != initialized_count; ++i) { + match_t result = operator[](i); + keys[i] = result.member.key; + } + for (; i != capacity; ++i) + keys[i] = vector_key_t{}; + + return initialized_count; + } + }; + + struct cluster_result_t { + error_t error{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + match_t cluster{}; + + explicit operator bool() const noexcept { return !error; } + cluster_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Inserts a new entry into the index. Thread-safe. Supports @b heterogeneous lookups. + * Expects needed capacity to be reserved ahead of time: `size() < capacity()`. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * Where any possible `entry_at` has both two interfaces: `std::size_t slot()`, `vector_key_t key()`. + * + * @param[in] key External identifier/name/descriptor for the new entry. + * @param[in] value Content that will be compared against other entries to index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t add( // + vector_key_t key, value_at&& value, metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + add_result_t result; + if (is_immutable()) + return result.failed("Can't add to an immutable index"); + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + // Determining how much memory to allocate for the node depends on the target level + std::unique_lock new_level_lock(global_mutex_); + level_t max_level_copy = max_level_; // Copy under lock + compressed_slot_t entry_slot_copy = static_cast(entry_slot_); // Copy under lock + level_t new_target_level = choose_random_level_(context.level_generator); + + // Make sure we are not overflowing + std::size_t capacity = nodes_capacity_.load(); + std::size_t old_size = nodes_count_.fetch_add(1); + if (old_size >= capacity) { + nodes_count_.fetch_sub(1); + return result.failed("Reserve capacity ahead of insertions!"); + } + + // Allocate the neighbors + node_t new_node = node_make_(key, new_target_level); + if (!new_node) { + nodes_count_.fetch_sub(1); + return result.failed("Out of memory!"); + } + if (new_target_level <= max_level_copy) + new_level_lock.unlock(); + + nodes_[old_size] = new_node; + result.new_size = old_size + 1; + compressed_slot_t new_slot = result.slot = static_cast(old_size); + callback(at(result.slot)); + + // Do nothing for the first element + if (!old_size) { + entry_slot_ = result.slot; + max_level_ = new_target_level; + return result; + } + + // Pull stats + result.computed_distances = context.computed_distances; + result.computed_distances_in_refines = context.computed_distances_in_refines; + result.computed_distances_in_reverse_refines = context.computed_distances_in_reverse_refines; + result.visited_members = context.iteration_cycles; + + // Go down the level, tracking only the closest match + compressed_slot_t closest_slot = search_for_one_( // + value, metric, prefetch, // + entry_slot_copy, max_level_copy, new_target_level, context); + + // From `new_target_level` down - perform proper extensive search + for (level_t level = (std::min)(new_target_level, max_level_copy); level >= 0; --level) { + // TODO: Handle out of memory conditions + search_to_insert_(value, metric, prefetch, closest_slot, level, config.expansion, context); + candidates_view_t closest_view; + { + node_lock_t new_lock = node_lock_(new_slot); + neighbors_(new_node, level).clear(); + closest_view = form_links_to_closest_(metric, new_slot, level, context); + closest_slot = closest_view[0].slot; + } + form_reverse_links_(metric, new_slot, closest_view, value, level, context); + } + + // Normalize stats + result.computed_distances = context.computed_distances - result.computed_distances; + result.computed_distances_in_refines = + context.computed_distances_in_refines - result.computed_distances_in_refines; + result.computed_distances_in_reverse_refines = + context.computed_distances_in_reverse_refines - result.computed_distances_in_reverse_refines; + result.visited_members = context.iteration_cycles - result.visited_members; + + // Updating the entry point if needed + if (new_target_level > max_level_copy) { + entry_slot_ = new_slot; + max_level_ = new_target_level; + } + return result; + } + + /** + * @brief Update an existing entry. Thread-safe. Supports @b heterogeneous lookups. + * + * ! It's assumed that different threads aren't updating the same entry at the same time. + * ! The state won't be corrupted, but no transactional guarantees are provided and the + * ! resulting value & neighbors list may be inconsistent. + * + * @tparam metric_at + * A function responsible for computing the distance @b (dis-similarity) between two objects. + * It should be callable into distinctly different scenarios: + * - `distance_t operator() (value_at, entry_at)` - from new object to existing entries. + * - `distance_t operator() (entry_at, entry_at)` - between existing entries. + * For any possible `entry_at` following interfaces will work: + * - `std::size_t get_slot(entry_at const &)` + * - `vector_key_t get_key(entry_at const &)` + * + * @param[in] iterator Iterator pointing to an existing entry to be replaced. + * @param[in] key External identifier/name/descriptor for the entry. + * @param[in] value Content that will be compared against other entries in the index. + * @param[in] metric Callable object measuring distance between ::value and present objects. + * @param[in] config Configuration options for this specific operation. + * @param[in] callback On-success callback, executed while the `member_ref_t` is still under lock. + */ + template < // + typename value_at, // + typename metric_at, // + typename callback_at = dummy_callback_t, // + typename prefetch_at = dummy_prefetch_t // + > + add_result_t update( // + member_iterator_t iterator, // + vector_key_t key, // + value_at&& value, // + metric_at&& metric, // + index_update_config_t config = {}, // + callback_at&& callback = callback_at{}, // + prefetch_at&& prefetch = prefetch_at{}) usearch_noexcept_m { + + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!config.expansion) + config.expansion = default_expansion_add(); + + usearch_assert_m(!is_immutable(), "Can't add to an immutable index"); + add_result_t result; + compressed_slot_t updated_slot = iterator.slot_; + + // Make sure we have enough local memory to perform this request + context_t& context = contexts_[config.thread]; + top_candidates_t& top = context.top_candidates; + next_candidates_t& next = context.next_candidates; + top.clear(); + next.clear(); + + // The top list needs one more slot than the connectivity of the base level + // for the heuristic, that tries to squeeze one more element into saturated list. + std::size_t connectivity_max = (std::max)(config_.connectivity_base, config_.connectivity); + std::size_t top_limit = (std::max)(connectivity_max + 1, config.expansion); + if (!top.reserve(top_limit)) + return result.failed("Out of memory!"); + if (!next.reserve(config.expansion)) + return result.failed("Out of memory!"); + + node_t updated_node = node_at_(updated_slot); + level_t updated_node_level = updated_node.level(); + + // Copy entry coordinates under locks + level_t max_level_copy; + compressed_slot_t entry_slot_copy; + { + std::unique_lock new_level_lock(global_mutex_); + max_level_copy = max_level_; // Copy under lock + entry_slot_copy = static_cast(entry_slot_); // Copy under lock + } + + // Pull stats + result.computed_distances = context.computed_distances; + result.visited_members = context.iteration_cycles; + + // Go down the level, tracking only the closest match; + // It may even be equal to the `updated_slot` + compressed_slot_t closest_slot = + // If we are updating the entry node itself, it won't contain any neighbors, + // so we should traverse a level down to find the closest match. + updated_node_level == max_level_copy // + ? entry_slot_copy + : search_for_one_( // + value, metric, prefetch, // + entry_slot_copy, max_level_copy, updated_node_level, context); + + // From `updated_node_level` down - perform proper extensive search + for (level_t level = (std::min)(updated_node_level, max_level_copy); level >= 0; --level) { + if (!search_to_update_(value, metric, prefetch, closest_slot, updated_slot, level, config.expansion, + context)) + return result.failed("Out of memory!"); + + candidates_view_t closest_view; + { + node_lock_t updated_lock = node_lock_(updated_slot); + // TODO: Go through existing neighbors removing reverse links + // for (compressed_slot_t slot : neighbors_(updated_node, level)) + // remove_link_(slot, updated_slot, level); + neighbors_(updated_node, level).clear(); + closest_view = form_links_to_closest_(metric, updated_slot, level, context); + if (closest_view.size()) + closest_slot = closest_view[0].slot; + } + form_reverse_links_(metric, updated_slot, closest_view, value, level, context); + } + updated_node.key(key); + + // Normalize stats + result.computed_distances = context.computed_distances - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.slot = updated_slot; + + callback(at(updated_slot)); + return result; + } + + /** + * @brief Searches for the closest elements to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] wanted The upper bound for the number of results to return. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. + */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + search_result_t search( // + value_at&& query, // + std::size_t wanted, // + metric_at&& metric, // + index_search_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const usearch_noexcept_m { + + // Someone is gonna fuzz this, so let's make sure we cover the basics + if (!wanted) + return search_result_t{}; + + // Expansion factor set to zero is equivalent to the default value + if (!config.expansion) + config.expansion = default_expansion_search(); + + // Using references is cleaner, but would result in UBSan false positives + context_t* context_ptr = contexts_.data() ? contexts_.data() + config.thread : nullptr; + top_candidates_t* top_ptr = context_ptr ? &context_ptr->top_candidates : nullptr; + search_result_t result{*this, top_ptr}; + if (!nodes_count_.load(std::memory_order_relaxed)) + return result; + + usearch_assert_m(contexts_.size() > config.thread, "Thread index out of bounds"); + context_t& context = *context_ptr; + top_candidates_t& top = *top_ptr; + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances; + result.visited_members = context.iteration_cycles; + + if (config.exact) { + if (!top.reserve(wanted)) + return result.failed("Out of memory!"); + search_exact_(query, metric, predicate, wanted, context); + } else { + next_candidates_t& next = context.next_candidates; + std::size_t expansion = (std::max)(config.expansion, wanted); + usearch_assert_m(expansion > 0, "Expansion factor can't be a zero!"); + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + if (!top.reserve(expansion)) + return result.failed("Out of memory!"); + + compressed_slot_t closest_slot = search_for_one_( + query, metric, prefetch, static_cast(entry_slot_), max_level_, 0, context); + + // For bottom layer we need a more optimized procedure + if (!search_to_find_in_base_(query, metric, predicate, prefetch, closest_slot, expansion, context)) + return result.failed("Out of memory!"); + } + + top.sort_ascending(); + top.shrink(wanted); + + // Normalize stats + result.computed_distances = context.computed_distances - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + result.count = top.size(); + return result; + } + + /** + * @brief Identifies the closest cluster to the given ::query. Thread-safe. + * + * @param[in] query Content that will be compared against other entries in the index. + * @param[in] level The index level to target. Higher means lower resolution. + * @param[in] config Configuration options for this specific operation. + * @param[in] predicate Optional filtering predicate for `member_cref_t`. + * @return Smart object referencing temporary memory. Valid until next `search()`, `add()`, or `cluster()`. + */ + template < // + typename value_at, // + typename metric_at, // + typename predicate_at = dummy_predicate_t, // + typename prefetch_at = dummy_prefetch_t // + > + cluster_result_t cluster( // + value_at&& query, // + std::size_t level, // + metric_at&& metric, // + index_cluster_config_t config = {}, // + predicate_at&& predicate = predicate_at{}, // + prefetch_at&& prefetch = prefetch_at{}) const noexcept { + + context_t& context = contexts_[config.thread]; + cluster_result_t result; + if (!nodes_count_) + return result.failed("No clusters to identify"); + + // Go down the level, tracking only the closest match + result.computed_distances = context.computed_distances; + result.visited_members = context.iteration_cycles; + + next_candidates_t& next = context.next_candidates; + std::size_t expansion = config.expansion; + if (!next.reserve(expansion)) + return result.failed("Out of memory!"); + + result.cluster.member = + at(search_for_one_(query, metric, prefetch, static_cast(entry_slot_), max_level_, + static_cast(level <= 0 ? 0 : level - 1), context)); + result.cluster.distance = context.measure(query, result.cluster.member, metric); + + // Normalize stats + result.computed_distances = context.computed_distances - result.computed_distances; + result.visited_members = context.iteration_cycles - result.visited_members; + + (void)predicate; + return result; + } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Metadata +#endif + + struct stats_t { + std::size_t nodes{}; + std::size_t edges{}; + std::size_t max_edges{}; + std::size_t allocated_bytes{}; + }; + + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage across all levels. + */ + stats_t stats() const noexcept { + stats_t result{}; + + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + std::size_t max_edges = node.level() * config_.connectivity + config_.connectivity_base; + std::size_t edges = 0; + for (level_t level = 0; level <= node.level(); ++level) + edges += neighbors_(node, level).size(); + + ++result.nodes; + result.allocated_bytes += node_bytes_(node).size(); + result.edges += edges; + result.max_edges += max_edges; + } + return result; + } + + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage up to a specific level. + * + * The `level` parameter is zero-based, where `0` is the base level. + * For example, `level=1` will include the base level and the first level of connections. + */ + stats_t stats(std::size_t level) const noexcept { + stats_t result{}; + std::size_t neighbors_bytes = !level ? pre_.neighbors_base_bytes : pre_.neighbors_bytes; + std::size_t max_edges_per_node = !level ? config_.connectivity_base : config_.connectivity; + + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + if (static_cast(node.level()) < level) + continue; + + ++result.nodes; + result.edges += neighbors_(node, level).size(); + result.allocated_bytes += node_head_bytes_() + neighbors_bytes; + } + + result.max_edges = result.nodes * max_edges_per_node; + return result; + } + + /** + * @brief Aggregates stats on the number of nodes, edges, and memory usage up to a specific level, + * simultaneously exporting the stats for each level into the `stats_per_level` C-style array. + * + * The `max_level` parameter is zero-based, where `0` is the base level. + * For example, `max_level=1` will include the base level and the first level of connections. + */ + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const noexcept { + + std::size_t head_bytes = node_head_bytes_(); + for (std::size_t i = 0; i != size(); ++i) { + node_t node = node_at_(i); + + stats_per_level[0].nodes++; + stats_per_level[0].edges += neighbors_(node, 0).size(); + stats_per_level[0].allocated_bytes += pre_.neighbors_base_bytes + head_bytes; + + level_t node_level = static_cast(node.level()); + for (level_t l = 1; l <= (std::min)(node_level, static_cast(max_level)); ++l) { + stats_per_level[l].nodes++; + stats_per_level[l].edges += neighbors_(node, l).size(); + stats_per_level[l].allocated_bytes += pre_.neighbors_bytes; + } + } + + // The `max_edges` parameter can be inferred from `nodes` + stats_per_level[0].max_edges = stats_per_level[0].nodes * config_.connectivity_base; + for (std::size_t l = 1; l <= max_level; ++l) + stats_per_level[l].max_edges = stats_per_level[l].nodes * config_.connectivity; + + // Aggregate stats across levels + stats_t result{}; + for (std::size_t l = 0; l <= max_level; ++l) + result.nodes += stats_per_level[l].nodes, // + result.edges += stats_per_level[l].edges, // + result.allocated_bytes += stats_per_level[l].allocated_bytes, // + result.max_edges += stats_per_level[l].max_edges; // + + return result; + } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage(std::size_t allocator_entry_bytes = default_allocator_entry_bytes()) const noexcept { + std::size_t total = 0; + if (!viewed_file_) { + stats_t s = stats(); + total += s.allocated_bytes; + total += s.nodes * allocator_entry_bytes; + } + + // Temporary data-structures, proportional to the number of nodes: + total += limits_.members * sizeof(node_t) + allocator_entry_bytes; + + // Temporary data-structures, proportional to the number of threads: + total += limits_.threads() * sizeof(context_t) + allocator_entry_bytes * 3; + return total; + } + + std::size_t memory_usage_per_node(level_t level) const noexcept { return node_bytes_(level); } + + double inverse_log_connectivity() const { return pre_.inverse_log_connectivity; } + + std::size_t neighbors_base_bytes() const { return pre_.neighbors_base_bytes; } + + std::size_t neighbors_bytes() const { return pre_.neighbors_bytes; } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion + +#pragma region Serialization +#endif + + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length() const noexcept { + std::size_t neighbors_length = 0; + for (std::size_t i = 0; i != size(); ++i) + neighbors_length += node_bytes_(node_at_(i).level()) + sizeof(level_t); + return sizeof(index_serialized_header_t) + neighbors_length; + } + + /** + * @brief Saves serialized binary index representation to a stream. + */ + template + serialization_result_t save_to_stream(output_callback_at&& output, progress_at&& progress = {}) const noexcept { + + serialization_result_t result; + + // Export some basic metadata + index_serialized_header_t header; + header.size = nodes_count_; + header.connectivity = config_.connectivity; + header.connectivity_base = config_.connectivity_base; + header.max_level = max_level_; + header.entry_slot = entry_slot_; + if (!output(&header, sizeof(header))) + return result.failed("Failed to serialize the header into stream"); + + // Progress status + std::size_t processed = 0; + std::size_t const total = 2 * header.size; + + // Export the number of levels per node + // That is both enough to estimate the overall memory consumption, + // and to be able to estimate the offsets of every entry in the file. + for (std::size_t i = 0; i != header.size; ++i) { + node_t node = node_at_(i); + level_t level = node.level(); + if (!output(&level, sizeof(level))) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + // After that dump the nodes themselves + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_bytes_(node_at_(i)); + if (!output(node_bytes.data(), node_bytes.size())) + return result.failed("Failed to serialize into stream"); + if (!progress(++processed, total)) + return result.failed("Terminated by user"); + } + + return {}; + } + + /** + * @brief Symmetric to `save_from_stream`, pulls data from a stream. + */ + template + serialization_result_t load_from_stream(input_callback_at&& input, progress_at&& progress = {}) noexcept { + + serialization_result_t result; + + // Remove previously stored objects + index_limits_t old_limits = limits_; + reset(); + + // Pull basic metadata + index_serialized_header_t header; + if (!input(&header, sizeof(header))) + return result.failed("Failed to pull the header from the stream"); + + // We are loading an empty index, no more work to do + if (!header.size) { + reset(); + return result; + } + + // Allocate some dynamic memory to read all the levels + using levels_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt levels(header.size); + if (!levels) + return result.failed("Out of memory"); + if (!input(levels, header.size * sizeof(level_t))) + return result.failed("Failed to pull nodes levels from the stream"); + + // Submit metadata + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + error_t error = config_.validate(); + if (error) + return result.failed(std::move(error)); + + pre_ = precompute_(config_); + index_limits_t limits; + limits.members = header.size; + limits.threads_add = (std::max)(1, old_limits.threads_add); + limits.threads_search = (std::max)(1, old_limits.threads_search); + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Load the nodes + for (std::size_t i = 0; i != header.size; ++i) { + span_bytes_t node_bytes = node_malloc_(levels[i]); + if (!input(node_bytes.data(), node_bytes.size())) { + reset(); + return result.failed("Failed to pull nodes from the stream"); + } + nodes_[i] = node_t{node_bytes.data()}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + return {}; + } + + template + serialization_result_t save(char const* file_path, progress_at&& progress = {}) const noexcept { + return save(output_file_t(file_path), std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, progress_at&& progress = {}) noexcept { + return load(input_file_t(file_path), std::forward(progress)); + } + + /** + * @brief Saves serialized binary index representation to a file, generally on disk. + */ + template + serialization_result_t save(output_file_t file, progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) { + // Drop generic messages like "end of file reached" in favor + // of more specific messages from the stream + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t save(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) const noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(input_file_t file, progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + std::forward(progress)); + + if (!stream_result) { + // Drop generic messages like "end of file reached" in favor + // of more specific messages from the stream + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Loads the serialized binary index representation from disk to RAM. + * Adjusts the configuration properties of the constructed index to + * match the settings in the file. + */ + template + serialization_result_t load(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + std::forward(progress)); + + return stream_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t view(memory_mapped_file_t file, std::size_t offset = 0, + progress_at&& progress = {}) noexcept { + + // Remove previously stored objects + index_limits_t old_limits = limits_; + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Pull basic metadata + index_serialized_header_t header; + if (file.size() - offset < sizeof(header)) + return result.failed("File is corrupted and lacks a header"); + std::memcpy(&header, file.data() + offset, sizeof(header)); + + if (!header.size) { + reset(); + return result; + } + + // Precompute offsets of every node, but before that we need to update the configs + // This could have been done with `std::exclusive_scan`, but it's only available from C++17. + using offsets_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt offsets(header.size); + if (!offsets) + return result.failed("Out of memory"); + + config_.connectivity = header.connectivity; + config_.connectivity_base = header.connectivity_base; + error_t error = config_.validate(); + if (error) + return result.failed(std::move(error)); + + pre_ = precompute_(config_); + misaligned_ptr_gt levels{(byte_t*)file.data() + offset + sizeof(header)}; + offsets[0u] = offset + sizeof(header) + sizeof(level_t) * header.size; + for (std::size_t i = 1; i < header.size; ++i) + offsets[i] = offsets[i - 1] + node_bytes_(levels[i - 1]); + + std::size_t total_bytes = offsets[header.size - 1] + node_bytes_(levels[header.size - 1]); + if (file.size() < total_bytes) { + reset(); + return result.failed("File is corrupted and can't fit all the nodes"); + } + + // Submit metadata and reserve memory + index_limits_t limits; + limits.members = header.size; + limits.threads_add = (std::max)(1, old_limits.threads_add); + limits.threads_search = (std::max)(1, old_limits.threads_search); + if (!reserve(limits)) { + reset(); + return result.failed("Out of memory"); + } + nodes_count_ = header.size; + max_level_ = static_cast(header.max_level); + entry_slot_ = static_cast(header.entry_slot); + + // Rapidly address all the nodes + for (std::size_t i = 0; i != header.size; ++i) { + nodes_[i] = node_t{(byte_t*)file.data() + offsets[i]}; + if (!progress(i + 1, header.size)) + return result.failed("Terminated by user"); + } + viewed_file_ = std::move(file); + return {}; + } + +#if defined(USEARCH_USE_PRAGMA_REGION) +#pragma endregion +#endif + + /** + * @brief Performs compaction on the whole HNSW index, purging some entries + * and links to them, while also generating a more efficient mapping, + * putting the more frequently used entries closer together. + * + * @param[in] values A []-subscriptable object, providing access to the values. + * @param[in] metric Callable object measuring distance between any ::values and present objects. + * @param[in] slot_transition Callable object to inform changes in slot assignments. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + * @param[in] prefetch Callable object to prefetch data into the cache. + */ + template + void compact( // + values_at&& values, // + metric_at&& metric, // + slot_transition_at&& slot_transition, // + + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}, // + prefetch_at&& prefetch = prefetch_at{}) noexcept { + + // Export all the keys, slots, and levels. + // Partition them with the predicate. + // Sort the allowed entries in descending order of their level. + // Create a new array mapping old slots to the new ones (INT_MAX for deleted items). + struct slot_level_t { + compressed_slot_t old_slot; + compressed_slot_t cluster; + level_t level; + }; + using slot_level_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt slots_and_levels(size()); + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + std::size_t const total = 3 * slots_and_levels.size(); + + // For every bottom level node, determine its parent cluster + executor.dynamic(slots_and_levels.size(), [&](std::size_t thread_idx, std::size_t old_slot_as_uint) { + context_t& context = contexts_[thread_idx]; + compressed_slot_t old_slot = static_cast(old_slot_as_uint); + compressed_slot_t cluster = search_for_one_( // + values[citerator_at(old_slot)], // + metric, prefetch, // + static_cast(entry_slot_), max_level_, 0, context); + slots_and_levels[old_slot] = {old_slot, cluster, node_at_(old_slot).level()}; + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), total); + return do_tasks.load(); + }); + if (!do_tasks.load()) + return; + + // Where the actual permutation happens: + std::sort(slots_and_levels.begin(), slots_and_levels.end(), [](slot_level_t const& a, slot_level_t const& b) { + return a.level == b.level ? a.cluster < b.cluster : a.level > b.level; + }); + + using size_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt old_slot_to_new(slots_and_levels.size()); + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) + old_slot_to_new[slots_and_levels[new_slot].old_slot] = new_slot; + + // Erase all the incoming links + buffer_gt reordered_nodes(slots_and_levels.size()); + tape_allocator_t reordered_tape; + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + node_t old_node = node_at_(old_slot); + + std::size_t node_bytes = node_bytes_(old_node.level()); + byte_t* new_data = (byte_t*)reordered_tape.allocate(node_bytes); + node_t new_node{new_data}; + std::memcpy(new_data, old_node.tape(), node_bytes); + + for (level_t level = 0; level <= old_node.level(); ++level) + for (misaligned_ref_gt neighbor : neighbors_(new_node, level)) + neighbor = static_cast(old_slot_to_new[compressed_slot_t(neighbor)]); + + reordered_nodes[new_slot] = new_node; + if (!progress(++processed, total)) + return; + } + + for (std::size_t new_slot = 0; new_slot != slots_and_levels.size(); ++new_slot) { + std::size_t old_slot = slots_and_levels[new_slot].old_slot; + slot_transition(node_at_(old_slot).ckey(), // + static_cast(old_slot), // + static_cast(new_slot)); + if (!progress(++processed, total)) + return; + } + + nodes_ = std::move(reordered_nodes); + tape_allocator_ = std::move(reordered_tape); + entry_slot_ = old_slot_to_new[entry_slot_]; + } + + /** + * @brief Scans the whole collection, removing the links leading towards + * banned entries. This essentially isolates some nodes from the rest + * of the graph, while keeping their outgoing links, in case the node + * is structurally relevant and has a crucial role in the index. + * It won't reclaim the memory. + * + * @param[in] allow_member Predicate to mark nodes for isolation. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ + template < // + typename allow_member_at = dummy_predicate_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + void isolate( // + allow_member_at&& allow_member, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + // Progress status + std::atomic do_tasks{true}; + std::atomic processed{0}; + + // Erase all the incoming links + std::size_t nodes_count = size(); + executor.dynamic(nodes_count, [&](std::size_t thread_idx, std::size_t node_idx) { + node_t node = node_at_(node_idx); + for (level_t level = 0; level <= node.level(); ++level) { + neighbors_ref_t neighbors = neighbors_(node, level); + neighbors.erase_if([&](compressed_slot_t neighbor_slot) { + node_t neighbor = node_at_(neighbor_slot); + return !allow_member(member_cref_t{neighbor.ckey(), neighbor_slot}); + }); + } + ++processed; + if (thread_idx == 0) + do_tasks = progress(processed.load(), nodes_count); + return do_tasks.load(); + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(processed.load(), nodes_count); + } + + private: + inline static precomputed_constants_t precompute_(index_config_t const& config) noexcept { + precomputed_constants_t pre; + pre.inverse_log_connectivity = 1.0 / std::log(static_cast(config.connectivity)); + pre.neighbors_bytes = config.connectivity * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + pre.neighbors_base_bytes = config.connectivity_base * sizeof(compressed_slot_t) + sizeof(neighbors_count_t); + return pre; + } + + using span_bytes_t = span_gt; + + inline span_bytes_t node_bytes_(node_t node) const noexcept { return {node.tape(), node_bytes_(node.level())}; } + inline std::size_t node_bytes_(level_t level) const noexcept { + return node_head_bytes_() + node_neighbors_bytes_(level); + } + inline std::size_t node_neighbors_bytes_(node_t node) const noexcept { return node_neighbors_bytes_(node.level()); } + inline std::size_t node_neighbors_bytes_(level_t level) const noexcept { + return pre_.neighbors_base_bytes + pre_.neighbors_bytes * level; + } + + span_bytes_t node_malloc_(level_t level) noexcept { + std::size_t node_bytes = node_bytes_(level); + byte_t* data = (byte_t*)tape_allocator_.allocate(node_bytes); + return data ? span_bytes_t{data, node_bytes} : span_bytes_t{}; + } + + node_t node_make_(vector_key_t key, level_t level) noexcept { + span_bytes_t node_bytes = node_malloc_(level); + if (!node_bytes) + return {}; + + std::memset(node_bytes.data(), 0, node_bytes.size()); + node_t node{(byte_t*)node_bytes.data()}; + node.key(key); + node.level(level); + return node; + } + + node_t node_make_copy_(span_bytes_t old_bytes) noexcept { + byte_t* data = (byte_t*)tape_allocator_.allocate(old_bytes.size()); + if (!data) + return {}; + std::memcpy(data, old_bytes.data(), old_bytes.size()); + return node_t{data}; + } + + void node_free_(std::size_t idx) noexcept { + if (viewed_file_) + return; + + node_t& node = nodes_[idx]; + tape_allocator_.deallocate(node.tape(), node_bytes_(node).size()); + node = node_t{}; + } + + inline node_t node_at_(std::size_t idx) const noexcept { return nodes_[idx]; } + inline neighbors_ref_t neighbors_base_(node_t node) const noexcept { return {node.neighbors_tape()}; } + + inline neighbors_ref_t neighbors_non_base_(node_t node, level_t level) const noexcept { + usearch_assert_m(level > 0 && level <= node.level(), "Linking to missing level"); + return {node.neighbors_tape() + pre_.neighbors_base_bytes + (level - 1) * pre_.neighbors_bytes}; + } + + inline neighbors_ref_t neighbors_(node_t node, level_t level) const noexcept { + return level ? neighbors_non_base_(node, level) : neighbors_base_(node); + } + + struct node_lock_t { + nodes_mutexes_t& mutexes; + std::size_t slot; + inline ~node_lock_t() noexcept { mutexes.atomic_reset(slot); } + }; + + inline node_lock_t node_lock_(std::size_t slot) const noexcept { + while (nodes_mutexes_.atomic_set(slot)) + ; + return {nodes_mutexes_, slot}; + } + + struct node_conditional_lock_t { + nodes_mutexes_t& mutexes; + std::size_t slot; + inline ~node_conditional_lock_t() noexcept { + if (slot != std::numeric_limits::max()) + mutexes.atomic_reset(slot); + } + }; + + inline node_conditional_lock_t node_try_conditional_lock_(std::size_t slot, bool condition, + bool& failed_to_acquire) const noexcept { + failed_to_acquire = condition ? nodes_mutexes_.atomic_set(slot) : false; + return {nodes_mutexes_, failed_to_acquire ? std::numeric_limits::max() : slot}; + } + + template + candidates_view_t form_links_to_closest_( // + metric_at&& metric, std::size_t new_slot, level_t level, context_t& context) usearch_noexcept_m { + + node_t new_node = node_at_(new_slot); + top_candidates_t& top = context.top_candidates; + usearch_assert_m(top.size() || !require_non_empty_ak, "No candidates found"); + candidates_view_t top_view = + refine_(metric, config_.connectivity, top, context, context.computed_distances_in_refines); + usearch_assert_m(top_view.size() || !require_non_empty_ak, "This would lead to isolated nodes"); + + // Outgoing links from `new_slot`: + neighbors_ref_t new_neighbors = neighbors_(new_node, level); + usearch_assert_m(!new_neighbors.size(), "The newly inserted element should have blank link list"); + for (std::size_t idx = 0; idx != top_view.size(); idx++) { + usearch_assert_m(!new_neighbors[idx], "Possible memory corruption"); + usearch_assert_m(level <= node_at_(top_view[idx].slot).level(), "Linking to missing level"); + new_neighbors.push_back(top_view[idx].slot); + } + + return top_view; + } + + template + void form_reverse_links_( // + metric_at&& metric, compressed_slot_t new_slot, candidates_view_t new_neighbors, value_at&& value, + level_t level, context_t& context) usearch_noexcept_m { + + top_candidates_t& top_for_refine = context.top_for_refine; + std::size_t const connectivity_max = level ? config_.connectivity : config_.connectivity_base; + + // Reverse links from the neighbors: + for (auto new_neighbor : new_neighbors) { + compressed_slot_t close_slot = new_neighbor.slot; + if (close_slot == new_slot) + continue; + node_lock_t close_lock = node_lock_(close_slot); + node_t close_node = node_at_(close_slot); + neighbors_ref_t close_header = neighbors_(close_node, level); + + // The node may have no neighbors only in one case, when it's the first one in the index, + // but that is problematic to track in multi-threaded environments, where the order of insertion + // is not guaranteed. + // usearch_assert_m(close_header.size() || new_slot == 1, "Possible corruption - isolated node"); + usearch_assert_m(close_header.size() <= connectivity_max, "Possible corruption - overflow"); + usearch_assert_m(close_slot != new_slot, "Self-loops are impossible"); + usearch_assert_m(level <= close_node.level(), "Linking to missing level"); + + // If `new_slot` is already present in the neighboring connections of `close_slot` + // then no need to modify any connections or run the heuristics. + if (close_header.size() < connectivity_max) { + close_header.push_back(new_slot); + continue; + } + + top_for_refine.clear(); + top_for_refine.insert_reserved({context.measure(value, citerator_at(close_slot), metric), new_slot}); + for (compressed_slot_t successor_slot : close_header) + top_for_refine.insert_reserved( + {context.measure(citerator_at(close_slot), citerator_at(successor_slot), metric), successor_slot}); + + // Export the results: + close_header.clear(); + candidates_view_t top_view = refine_(metric, connectivity_max, top_for_refine, context, + context.computed_distances_in_reverse_refines); + usearch_assert_m(top_view.size(), "This would lead to isolated nodes"); + for (std::size_t idx = 0; idx != top_view.size(); idx++) + close_header.push_back(top_view[idx].slot); + } + } + + level_t choose_random_level_(std::default_random_engine& level_generator) const noexcept { + std::uniform_real_distribution distribution(0.0, 1.0); + double r = -std::log(distribution(level_generator)) * pre_.inverse_log_connectivity; + return (level_t)r; + } + + struct candidates_range_t; + class candidates_iterator_t { + friend struct candidates_range_t; + + index_gt const& index_; + neighbors_ref_t neighbors_; + visits_hash_set_t& visits_; + std::size_t current_; + + candidates_iterator_t& skip_missing() noexcept { + if (!visits_.size()) + return *this; + while (current_ != neighbors_.size()) { + compressed_slot_t neighbor_slot = neighbors_[current_]; + if (visits_.test(neighbor_slot)) + current_++; + else + break; + } + return *this; + } + + public: + using element_t = compressed_slot_t; + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = misaligned_ptr_gt; + using reference = misaligned_ref_gt; + + value_type operator*() const noexcept { return neighbors_[current_]; } + candidates_iterator_t(index_gt const& index, neighbors_ref_t neighbors, visits_hash_set_t& visits, + std::size_t progress) noexcept + : index_(index), neighbors_(neighbors), visits_(visits), current_(progress) {} + candidates_iterator_t operator++(int) noexcept { + return candidates_iterator_t(index_, neighbors_, visits_, current_ + 1).skip_missing(); + } + candidates_iterator_t& operator++() noexcept { + ++current_; + skip_missing(); + return *this; + } + bool operator==(candidates_iterator_t const& other) noexcept { return current_ == other.current_; } + bool operator!=(candidates_iterator_t const& other) noexcept { return current_ != other.current_; } + + vector_key_t key() const noexcept { return index_.node_at_(slot()).key(); } + compressed_slot_t slot() const noexcept { return neighbors_[current_]; } + friend inline std::size_t get_slot(candidates_iterator_t const& it) noexcept { return it.slot(); } + friend inline vector_key_t get_key(candidates_iterator_t const& it) noexcept { return it.key(); } + }; + + struct candidates_range_t { + index_gt const& index; + neighbors_ref_t neighbors; + visits_hash_set_t& visits; + + candidates_iterator_t begin() const noexcept { + return candidates_iterator_t{index, neighbors, visits, 0}.skip_missing(); + } + candidates_iterator_t end() const noexcept { return {index, neighbors, visits, neighbors.size()}; } + }; + + template + compressed_slot_t search_for_one_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + compressed_slot_t closest_slot, level_t begin_level, level_t end_level, context_t& context) const noexcept { + + visits_hash_set_t& visits = context.visits; + visits.clear(); + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(closest_slot), citerator_at(closest_slot) + 1); + + distance_t closest_dist = context.measure(query, citerator_at(closest_slot), metric); + for (level_t level = begin_level; level > end_level; --level) { + bool changed; + do { + changed = false; + node_lock_t closest_lock = node_lock_(closest_slot); + neighbors_ref_t closest_neighbors = neighbors_non_base_(node_at_(closest_slot), level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, closest_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Actual traversal + for (compressed_slot_t candidate_slot : closest_neighbors) { + distance_t candidate_dist = context.measure(query, citerator_at(candidate_slot), metric); + if (candidate_dist < closest_dist) { + closest_dist = candidate_dist; + closest_slot = candidate_slot; + changed = true; + } + } + + context.iteration_cycles++; + } while (changed); + } + return closest_slot; + } + + /** + * @brief Traverses a layer of a graph, to find the best place to insert a new node. + * Locks the nodes in the process, assuming other threads are updating neighbors lists. + * @return `true` if procedure succeeded, `false` if run out of memory. + */ + template + bool search_to_insert_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + compressed_slot_t start_slot, level_t level, std::size_t top_limit, context_t& context) noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + + visits.clear(); + next.clear(); + top.clear(); + + // At the very least we are going to explore the starting node and its neighbors + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, start_slot}); + top.insert_reserved({radius, start_slot}); + visits.set(start_slot); + + // The primary loop of the graph traversal + while (!next.empty()) { + + candidate_t candidacy = next.top(); + if ((-candidacy.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + compressed_slot_t candidate_slot = candidacy.slot; + node_t candidate_ref = node_at_(candidate_slot); + node_lock_t candidate_lock = node_lock_(candidate_slot); + neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + // We don't access the neighbors of the `successor_slot` node, + // so we don't have to lock it. + // node_lock_t successor_lock = node_lock_(successor_slot); + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + // This will automatically evict poor matches: + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + return true; + } + + /** + * @brief Traverses a layer of a graph, to find the best neighbors list for updated node. + * Locks the nodes in the process, assuming other threads are updating neighbors lists. + * @return `true` if procedure succeeded, `false` if run out of memory. + */ + template + bool search_to_update_( // + value_at&& query, metric_at&& metric, prefetch_at&& prefetch, // + compressed_slot_t start_slot, compressed_slot_t updated_slot, level_t level, std::size_t top_limit, + context_t& context) noexcept { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + + visits.clear(); + next.clear(); + top.clear(); + + // At the very least we are going to explore the starting node and its neighbors + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + next.insert_reserved({-radius, start_slot}); + visits.set(start_slot); + if (start_slot != updated_slot) + top.insert_reserved({radius, start_slot}); + + // The primary loop of the graph traversal + while (!next.empty()) { + + candidate_t candidacy = next.top(); + if ((-candidacy.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + compressed_slot_t candidate_slot = candidacy.slot; + node_t candidate_ref = node_at_(candidate_slot); + + // The trickiest part of update-heavy workloads is mitigating dead-locks + // in connected nodes during traversal. A "good enough" solution would be + // to skip concurrent access, assuming the other "close" node is gonna add + // this one when forming reverse connections. + bool failed_to_acquire = false; + node_conditional_lock_t candidate_lock = + node_try_conditional_lock_(candidate_slot, updated_slot != candidate_slot, failed_to_acquire); + if (failed_to_acquire) + continue; + neighbors_ref_t candidate_neighbors = neighbors_(candidate_ref, level); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + // We don't access the neighbors of the `successor_slot` node, + // so we don't have to lock it. + // node_conditional_lock_t successor_lock = + // node_try_conditional_lock_(successor_slot, updated_slot != successor_slot); + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + // This will automatically evict poor matches: + if (updated_slot != successor_slot) + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + return true; + } + + /** + * @brief Traverses the @b base layer of a graph, to find a close match. + * Doesn't lock any nodes, assuming read-only simultaneous access. + * @return `true` if procedure succeeded, `false` if run out of memory. + */ + template + bool search_to_find_in_base_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, prefetch_at&& prefetch, // + compressed_slot_t start_slot, std::size_t expansion, context_t& context) const usearch_noexcept_m { + + visits_hash_set_t& visits = context.visits; + next_candidates_t& next = context.next_candidates; // pop min, push + top_candidates_t& top = context.top_candidates; // pop max, push + std::size_t const top_limit = expansion; + + visits.clear(); + next.clear(); + top.clear(); + if (!visits.reserve(config_.connectivity_base + 1u)) + return false; + + // Optional prefetching + if (!is_dummy()) + prefetch(citerator_at(start_slot), citerator_at(start_slot) + 1); + + distance_t radius = context.measure(query, citerator_at(start_slot), metric); + usearch_assert_m(next.capacity(), "The `max_heap_gt` must have been reserved in the search entry point"); + next.insert_reserved({-radius, start_slot}); + visits.set(start_slot); + + // Don't populate the top list if the predicate is not satisfied + if (is_dummy() || predicate(member_cref_t{node_at_(start_slot).ckey(), start_slot})) { + usearch_assert_m(top.capacity(), + "The `sorted_buffer_gt` must have been reserved in the search entry point"); + top.insert_reserved({radius, start_slot}); + } + + while (!next.empty()) { + + candidate_t candidate = next.top(); + if ((-candidate.distance) > radius && top.size() == top_limit) + break; + + next.pop(); + context.iteration_cycles++; + + neighbors_ref_t candidate_neighbors = neighbors_base_(node_at_(candidate.slot)); + + // Optional prefetching + if (!is_dummy()) { + candidates_range_t missing_candidates{*this, candidate_neighbors, visits}; + prefetch(missing_candidates.begin(), missing_candidates.end()); + } + + // Assume the worst-case when reserving memory + if (!visits.reserve(visits.size() + candidate_neighbors.size())) + return false; + + for (compressed_slot_t successor_slot : candidate_neighbors) { + if (visits.set(successor_slot)) + continue; + + distance_t successor_dist = context.measure(query, citerator_at(successor_slot), metric); + if (top.size() < top_limit || successor_dist < radius) { + // This can substantially grow our priority queue: + next.insert({-successor_dist, successor_slot}); + if (is_dummy() || + predicate(member_cref_t{node_at_(successor_slot).ckey(), successor_slot})) { + top.insert({successor_dist, successor_slot}, top_limit); + radius = top.top().distance; + } + } + } + } + + return true; + } + + /** + * @brief Iterates through all members, without actually touching the index. + */ + template + void search_exact_( // + value_at&& query, metric_at&& metric, predicate_at&& predicate, // + std::size_t count, context_t& context) const noexcept { + + top_candidates_t& top = context.top_candidates; + top.clear(); + top.reserve(count); + for (std::size_t i = 0; i != size(); ++i) { + auto slot = static_cast(i); + if (!is_dummy()) + if (!predicate(at(slot))) + continue; + + distance_t distance = context.measure(query, citerator_at(slot), metric); + top.insert(candidate_t{distance, slot}, count); + } + } + + /** + * @brief This algorithm from the original paper implements a heuristic, + * that massively reduces the number of connections a point has, + * to keep only the neighbors, that are from each other. + */ + template + candidates_view_t refine_( // + metric_at&& metric, // + std::size_t needed, top_candidates_t& top, context_t& context, // + std::size_t& refines_counter) const noexcept { + + // Avoid expensive computation, if the set is already small + candidate_t* top_data = top.data(); + std::size_t const top_count = top.size(); + if (top_count < needed) + return {top_data, top_count}; + + // Sort before processing + top.sort_ascending(); + + std::size_t submitted_count = 1; + std::size_t consumed_count = 1; /// Always equal or greater than `submitted_count`. + while (submitted_count < needed && consumed_count < top_count) { + candidate_t candidate = top_data[consumed_count]; + bool good = true; + std::size_t idx = 0; + for (; idx < submitted_count; idx++) { + candidate_t submitted = top_data[idx]; + distance_t inter_result_dist = context.measure( // + citerator_at(candidate.slot), // + citerator_at(submitted.slot), // + metric); + if (inter_result_dist < candidate.distance) { + good = false; + break; + } + } + refines_counter += idx; + + if (good) { + top_data[submitted_count] = top_data[consumed_count]; + submitted_count++; + } + consumed_count++; + } + + top.shrink(submitted_count); + return {top_data, submitted_count}; + } +}; + +struct join_result_t { + error_t error{}; + std::size_t intersection_size{}; + std::size_t engagements{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + join_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::men keys to ::women. + * @param[inout] woman_to_man Container to map ::women keys to ::men. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ +template < // + + typename men_at, // + typename women_at, // + typename men_values_at, // + typename women_values_at, // + typename men_metric_at, // + typename women_metric_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > +static join_result_t join( // + men_at const& men, // + women_at const& women, // + men_values_at const& men_values, // + women_values_at const& women_values, // + men_metric_at&& men_metric, // + women_metric_at&& women_metric, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) noexcept { + + if (women.size() < men.size()) + return unum::usearch::join( // + women, men, // + women_values, men_values, // + std::forward(women_metric), std::forward(men_metric), // + + config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); + + join_result_t result; + + // Sanity checks and argument validation: + if (&men == &women) + return result.failed("Can't join with itself, consider copying"); + + if (config.max_proposals == 0) + config.max_proposals = std::log(men.size()) + executor.size(); + + using proposals_count_t = std::uint16_t; + config.max_proposals = (std::min)(men.size(), config.max_proposals); + + using distance_t = typename men_at::distance_t; + using dynamic_allocator_traits_t = typename men_at::dynamic_allocator_traits_t; + using man_key_t = typename men_at::vector_key_t; + using woman_key_t = typename women_at::vector_key_t; + + // Use the `compressed_slot_t` type of the larger collection + using compressed_slot_t = typename women_at::compressed_slot_t; + using compressed_slot_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + using proposals_count_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + + // Create an atomic queue, as a ring structure, from/to which + // free men will be added/pulled. + std::mutex free_men_mutex{}; + ring_gt free_men; + free_men.reserve(men.size()); + for (std::size_t i = 0; i != men.size(); ++i) + free_men.push(static_cast(i)); + + // We are gonna need some temporary memory. + buffer_gt proposal_counts(men.size()); + buffer_gt man_to_woman_slots(men.size()); + buffer_gt woman_to_man_slots(women.size()); + if (!proposal_counts || !man_to_woman_slots || !woman_to_man_slots) + return result.failed("Can't temporary mappings"); + + compressed_slot_t missing_slot; + std::memset((void*)&missing_slot, 0xFF, sizeof(compressed_slot_t)); + std::memset((void*)man_to_woman_slots.data(), 0xFF, sizeof(compressed_slot_t) * men.size()); + std::memset((void*)woman_to_man_slots.data(), 0xFF, sizeof(compressed_slot_t) * women.size()); + std::memset(proposal_counts.data(), 0, sizeof(proposals_count_t) * men.size()); + + // Define locks, to limit concurrent accesses to `man_to_woman_slots` and `woman_to_man_slots`. + bitset_t men_locks(men.size()), women_locks(women.size()); + if (!men_locks || !women_locks) + return result.failed("Can't allocate locks"); + + std::atomic rounds{0}; + std::atomic engagements{0}; + std::atomic computed_distances{0}; + std::atomic visited_members{0}; + std::atomic atomic_error{nullptr}; + + // Concurrently process all the men + executor.parallel([&](std::size_t thread_idx) { + index_search_config_t search_config; + search_config.expansion = config.expansion; + search_config.exact = config.exact; + search_config.thread = thread_idx; + compressed_slot_t free_man_slot; + + // While there exist a free man who still has a woman to propose to. + while (!atomic_error.load(std::memory_order_relaxed)) { + std::size_t passed_rounds = 0; + std::size_t total_rounds = 0; + { + std::unique_lock pop_lock(free_men_mutex); + if (!free_men.try_pop(free_man_slot)) + // Primary exit path, we have exhausted the list of candidates + break; + passed_rounds = ++rounds; + total_rounds = passed_rounds + free_men.size(); + } + if (thread_idx == 0 && !progress(passed_rounds, total_rounds)) { + atomic_error.store("Terminated by user"); + break; + } + while (men_locks.atomic_set(free_man_slot)) + ; + + proposals_count_t& free_man_proposals = proposal_counts[free_man_slot]; + if (free_man_proposals >= config.max_proposals) + continue; + + // Find the closest woman, to whom this man hasn't proposed yet. + ++free_man_proposals; + auto candidates = women.search(men_values[free_man_slot], free_man_proposals, women_metric, search_config); + visited_members += candidates.visited_members; + computed_distances += candidates.computed_distances; + if (!candidates) { + atomic_error = candidates.error.release(); + break; + } + + auto match = candidates.back(); + auto woman = match.member; + while (women_locks.atomic_set(woman.slot)) + ; + + compressed_slot_t husband_slot = woman_to_man_slots[woman.slot]; + bool woman_is_free = husband_slot == missing_slot; + if (woman_is_free) { + // Engagement + man_to_woman_slots[free_man_slot] = static_cast(woman.slot); + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + } else { + distance_t distance_from_husband = + women_metric(women_values[static_cast(woman.slot)], men_values[husband_slot]); + distance_t distance_from_candidate = match.distance; + if (distance_from_husband > distance_from_candidate) { + // Break-up + while (men_locks.atomic_set(husband_slot)) + ; + man_to_woman_slots[husband_slot] = missing_slot; + men_locks.atomic_reset(husband_slot); + + // New Engagement + man_to_woman_slots[free_man_slot] = static_cast(woman.slot); + woman_to_man_slots[woman.slot] = free_man_slot; + engagements++; + + std::unique_lock push_lock(free_men_mutex); + free_men.push(husband_slot); + } else { + std::unique_lock push_lock(free_men_mutex); + free_men.push(free_man_slot); + } + } + + men_locks.atomic_reset(free_man_slot); + women_locks.atomic_reset(woman.slot); + } + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Export the "slots" into keys: + std::size_t intersection_size = 0; + for (std::size_t man_slot = 0; man_slot != men.size(); ++man_slot) { + compressed_slot_t woman_slot = man_to_woman_slots[man_slot]; + if (woman_slot != missing_slot) { + man_key_t man = men.at(static_cast(man_slot)).key; + woman_key_t woman = women.at(woman_slot).key; + man_to_woman[man] = woman; + woman_to_man[woman] = man; + intersection_size++; + } + } + + // Export stats + result.engagements = engagements; + result.intersection_size = intersection_size; + result.computed_distances = computed_distances; + result.visited_members = visited_members; + return result; +} + +} // namespace usearch +} // namespace unum + +#endif diff --git a/zig/usearch/include/index_dense.hpp b/zig/usearch/include/index_dense.hpp new file mode 100644 index 000000000..43e9b5ede --- /dev/null +++ b/zig/usearch/include/index_dense.hpp @@ -0,0 +1,2273 @@ +/** + * @file index_dense.hpp + * @author Ash Vardanian + * @brief Single-header Vector Search engine for equi-dimensional dense vectors. + * @date July 26, 2023 + */ +#pragma once +#include "index.hpp" +#include // `aligned_alloc` + +#include "index.hpp" +#include "index_plugins.hpp" + +#if defined(USEARCH_DEFINED_CPP17) +#include // `std::shared_mutex` +#endif + +namespace unum { +namespace usearch { + +template class index_dense_gt; + +/** + * @brief The "magic" sequence helps infer the type of the file. + * USearch indexes start with the "usearch" string. + */ +constexpr char const* default_magic() { return "usearch"; } + +using index_dense_head_buffer_t = byte_t[64]; + +static_assert(sizeof(index_dense_head_buffer_t) == 64, "File header should be exactly 64 bytes"); + +/** + * @brief Serialized binary representations of the USearch index start with metadata. + * Metadata is parsed into a `index_dense_head_t`, containing the USearch package version, + * and the properties of the index. + * + * It uses: 13 bytes for file versioning, 22 bytes for structural information = 35 bytes. + * The following 24 bytes contain binary size of the graph, of the vectors, and the checksum, + * leaving 5 bytes at the end vacant. + */ +struct index_dense_head_t { + + // Versioning: + using magic_t = char[7]; + using version_t = std::uint16_t; + + // Versioning: 7 + 2 * 3 = 13 bytes + char const* magic; + misaligned_ref_gt version_major; + misaligned_ref_gt version_minor; + misaligned_ref_gt version_patch; + + // Structural: 4 * 3 = 12 bytes + misaligned_ref_gt kind_metric; + misaligned_ref_gt kind_scalar; + misaligned_ref_gt kind_key; + misaligned_ref_gt kind_compressed_slot; + + // Population: 8 * 3 = 24 bytes + misaligned_ref_gt count_present; + misaligned_ref_gt count_deleted; + misaligned_ref_gt dimensions; + misaligned_ref_gt multi; + + index_dense_head_t(byte_t* ptr) noexcept + : magic((char const*)exchange(ptr, ptr + sizeof(magic_t))), // + version_major(exchange(ptr, ptr + sizeof(version_t))), // + version_minor(exchange(ptr, ptr + sizeof(version_t))), // + version_patch(exchange(ptr, ptr + sizeof(version_t))), // + kind_metric(exchange(ptr, ptr + sizeof(metric_kind_t))), // + kind_scalar(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_key(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + kind_compressed_slot(exchange(ptr, ptr + sizeof(scalar_kind_t))), // + count_present(exchange(ptr, ptr + sizeof(std::uint64_t))), // + count_deleted(exchange(ptr, ptr + sizeof(std::uint64_t))), // + dimensions(exchange(ptr, ptr + sizeof(std::uint64_t))), // + multi(exchange(ptr, ptr + sizeof(bool))) {} +}; + +struct index_dense_head_result_t { + + index_dense_head_buffer_t buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_head_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Configuration settings for the construction of dense + * equidimensional vector indexes. + * + * Unlike the underlying `index_gt` class, incorporates the + * `::expansion_add` and `::expansion_search` parameters passed + * separately for the lower-level engine. + */ +struct index_dense_config_t : public index_config_t { + std::size_t expansion_add = default_expansion_add(); + std::size_t expansion_search = default_expansion_search(); + + /** + * @brief Excludes vectors from the serialized file. + * This is handy when you want to store the vectors in a separate file. + * + * ! For advanced users only. + */ + bool exclude_vectors = false; + + /** + * @brief Allows you to store multiple vectors per key. + * This is handy when a large document is chunked into many parts. + * + * ! May degrade the performance of iterators. + */ + bool multi = false; + + /** + * @brief Allows you to reduce RAM consumption by avoiding + * reverse-indexing keys-to-vectors, and only keeping + * the vectors-to-keys mappings. + * + * ! This configuration parameter doesn't affect the serialized file, + * ! and is not preserved between runs. Makes sense for smaller vectors + * ! that fit in a couple of cache lines. + * + * The trade-off is that some methods won't be available, like `get`, `rename`, + * and `remove`. The basic functionality, like `add` and `search` will work as + * expected even with `enable_key_lookups = false`. + * + * If both `!multi && !enable_key_lookups`, the "duplicate entry" checks won't + * be performed and no errors will be raised. + */ + bool enable_key_lookups = true; + + inline index_dense_config_t(index_config_t base) noexcept : index_config_t(base) {} + + inline index_dense_config_t(std::size_t c = 0, std::size_t ea = 0, std::size_t es = 0) noexcept + : index_config_t(c), expansion_add(ea), expansion_search(es) {} + + /** + * @brief Validates the configuration settings, updating them in-place. + * @return Error message, if any. + */ + inline error_t validate() noexcept { + error_t error = index_config_t::validate(); + if (error) + return error; + if (expansion_add == 0) + expansion_add = default_expansion_add(); + if (expansion_search == 0) + expansion_search = default_expansion_search(); + return {}; + } +}; + +struct index_dense_clustering_config_t { + std::size_t min_clusters = 0; + std::size_t max_clusters = 0; + enum mode_t { + merge_smallest_k, + merge_closest_k, + } mode = merge_smallest_k; +}; + +struct index_dense_serialization_config_t { + bool exclude_vectors = false; + bool use_64_bit_dimensions = false; +}; + +struct index_dense_copy_config_t : public index_copy_config_t { + bool force_vector_copy = true; + + index_dense_copy_config_t() = default; + index_dense_copy_config_t(index_copy_config_t base) noexcept : index_copy_config_t(base) {} +}; + +struct index_dense_metadata_result_t { + index_dense_serialization_config_t config; + index_dense_head_buffer_t head_buffer; + index_dense_head_t head; + error_t error; + + explicit operator bool() const noexcept { return !error; } + index_dense_metadata_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + + index_dense_metadata_result_t() noexcept : config(), head_buffer(), head(head_buffer), error() {} + + index_dense_metadata_result_t(index_dense_metadata_result_t&& other) noexcept + : config(), head_buffer(), head(head_buffer), error(std::move(other.error)) { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + } + + index_dense_metadata_result_t& operator=(index_dense_metadata_result_t&& other) noexcept { + std::memcpy(&config, &other.config, sizeof(other.config)); + std::memcpy(&head_buffer, &other.head_buffer, sizeof(other.head_buffer)); + error = std::move(other.error); + return *this; + } +}; + +/** + * @brief Fixes serialized scalar-kind codes for pre-v2.10 versions, until we can upgrade to v3. + * The old enum `scalar_kind_t` is defined without explicit constants from 0. + */ +inline scalar_kind_t convert_pre_2_10_scalar_kind(scalar_kind_t scalar_kind) noexcept { + switch (static_cast::type>(scalar_kind)) { + case 0: return scalar_kind_t::unknown_k; + case 1: return scalar_kind_t::b1x8_k; + case 2: return scalar_kind_t::u40_k; + case 3: return scalar_kind_t::uuid_k; + case 4: return scalar_kind_t::f64_k; + case 5: return scalar_kind_t::f32_k; + case 6: return scalar_kind_t::f16_k; + case 7: return scalar_kind_t::f8_k; + case 8: return scalar_kind_t::u64_k; + case 9: return scalar_kind_t::u32_k; + case 10: return scalar_kind_t::u8_k; + case 11: return scalar_kind_t::i64_k; + case 12: return scalar_kind_t::i32_k; + case 13: return scalar_kind_t::i16_k; + case 14: return scalar_kind_t::i8_k; + default: return scalar_kind; + } +} + +/** + * @brief Fixes the metadata for pre-v2.10 versions, until we can upgrade to v3. + * Originates from: https://github.com/unum-cloud/usearch/issues/423 + */ +inline void fix_pre_2_10_metadata(index_dense_head_t& head) { + if (head.version_major == 2 && head.version_minor < 10) { + head.kind_scalar = convert_pre_2_10_scalar_kind(head.kind_scalar); + head.kind_key = convert_pre_2_10_scalar_kind(head.kind_key); + head.kind_compressed_slot = convert_pre_2_10_scalar_kind(head.kind_compressed_slot); + head.version_minor = 10; + head.version_patch = 0; + } +} + +/** + * @brief Extracts metadata from a pre-constructed index on disk, + * without loading it or mapping the whole binary file. + */ +inline index_dense_metadata_result_t index_dense_metadata_from_path(char const* file_path) noexcept { + index_dense_metadata_result_t result; + std::unique_ptr file(std::fopen(file_path, "rb"), &std::fclose); + if (!file) + return result.failed(std::strerror(errno)); + + // Read the header + std::size_t read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); + return result; + } + + if (std::fseek(file.get(), 0L, SEEK_END) != 0) + return result.failed("Can't infer file size"); + + // Check if it starts with 32-bit + std::size_t const file_size = std::ftell(file.get()); + + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u32), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); + return result; + } + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + if (std::fseek(file.get(), static_cast(offset_if_u64), SEEK_SET) != 0) + return result.failed(std::strerror(errno)); + read = std::fread(result.head_buffer, sizeof(index_dense_head_buffer_t), 1, file.get()); + if (!read) + return result.failed(std::feof(file.get()) ? "End of file reached!" : std::strerror(errno)); + + // Check if it starts with 64-bit + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) { + fix_pre_2_10_metadata(result.head); + return result; + } + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Extracts metadata from a pre-constructed index serialized into an in-memory buffer. + */ +inline index_dense_metadata_result_t index_dense_metadata_from_buffer(memory_mapped_file_t const& file, + std::size_t offset = 0) noexcept { + index_dense_metadata_result_t result; + + // Read the header + if (offset + sizeof(index_dense_head_buffer_t) >= file.size()) + return result.failed("End of file reached!"); + + byte_t const* file_data = file.data() + offset; + std::size_t const file_size = file.size() - offset; + std::memcpy(&result.head_buffer, file_data, sizeof(index_dense_head_buffer_t)); + + // Check if the file immediately starts with the index, instead of vectors + result.config.exclude_vectors = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + + // Check if it starts with 32-bit + std::uint32_t dimensions_u32[2]{0}; + std::memcpy(dimensions_u32, result.head_buffer, sizeof(dimensions_u32)); + std::size_t offset_if_u32 = std::size_t(dimensions_u32[0]) * dimensions_u32[1] + sizeof(dimensions_u32); + + std::uint64_t dimensions_u64[2]{0}; + std::memcpy(dimensions_u64, result.head_buffer, sizeof(dimensions_u64)); + std::size_t offset_if_u64 = std::size_t(dimensions_u64[0]) * dimensions_u64[1] + sizeof(dimensions_u64); + + // Check if it starts with 32-bit + if (offset_if_u32 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u32, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = false; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + // Check if it starts with 64-bit + if (offset_if_u64 + sizeof(index_dense_head_buffer_t) < file_size) { + std::memcpy(&result.head_buffer, file_data + offset_if_u64, sizeof(index_dense_head_buffer_t)); + result.config.exclude_vectors = false; + result.config.use_64_bit_dimensions = true; + if (std::memcmp(result.head_buffer, default_magic(), std::strlen(default_magic())) == 0) + return result; + } + + return result.failed("Not a dense USearch index!"); +} + +/** + * @brief Oversimplified type-punned index for equidimensional vectors + * with automatic @b down-casting, hardware-specific @b SIMD metrics, + * and ability to @b remove existing vectors, common in Semantic Caching + * applications. + * + * @section Serialization + * + * The serialized binary form of `index_dense_gt` is made up of three parts: + * 1. Binary matrix, aka the `.bbin` part, + * 2. Metadata about used metrics, number of used vs free slots, + * 3. The HNSW index in a binary form. + * The first (1.) generally starts with 2 integers - number of rows (vectors) and @b single-byte columns. + * The second (2.) starts with @b "usearch"-magic-string, used to infer the file type on open. + * The third (3.) is implemented by the underlying `index_gt` class. + */ +template // +class index_dense_gt { + public: + using vector_key_t = key_at; + using key_t = vector_key_t; + using compressed_slot_t = compressed_slot_at; + using distance_t = distance_punned_t; + using metric_t = metric_punned_t; + + using member_ref_t = member_ref_gt; + using member_cref_t = member_cref_gt; + + using head_t = index_dense_head_t; + using head_buffer_t = index_dense_head_buffer_t; + using head_result_t = index_dense_head_result_t; + + using serialization_config_t = index_dense_serialization_config_t; + + using dynamic_allocator_t = aligned_allocator_gt; + using tape_allocator_t = memory_mapping_allocator_gt<64>; + + private: + /// @brief Punned index. + using index_t = index_gt< // + distance_t, vector_key_t, compressed_slot_t, // + dynamic_allocator_t, tape_allocator_t>; + using index_allocator_t = aligned_allocator_gt; + + using member_iterator_t = typename index_t::member_iterator_t; + using member_citerator_t = typename index_t::member_citerator_t; + + /// @brief Punned metric object. + class metric_proxy_t { + index_dense_gt const* index_ = nullptr; + + public: + metric_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + + inline distance_t operator()(byte_t const* a, member_cref_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_cref_t a, member_cref_t b) const noexcept { return f(v(a), v(b)); } + + inline distance_t operator()(byte_t const* a, member_citerator_t b) const noexcept { return f(a, v(b)); } + inline distance_t operator()(member_citerator_t a, member_citerator_t b) const noexcept { + return f(v(a), v(b)); + } + + inline distance_t operator()(byte_t const* a, byte_t const* b) const noexcept { return f(a, b); } + + inline byte_t const* v(member_cref_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline byte_t const* v(member_citerator_t m) const noexcept { return index_->vectors_lookup_[get_slot(m)]; } + inline distance_t f(byte_t const* a, byte_t const* b) const noexcept { return index_->metric_(a, b); } + }; + + index_dense_config_t config_; + index_t* typed_ = nullptr; + + using cast_buffer_t = buffer_gt; + + /// @brief Temporary memory for every thread to store a casted vector. + mutable cast_buffer_t cast_buffer_; + casts_punned_t casts_; + + /// @brief An instance of a potentially stateful `metric_t` used to initialize copies and forks. + metric_t metric_; + + using vectors_tape_allocator_t = memory_mapping_allocator_gt<8>; + /// @brief Allocator for the copied vectors, aligned to widest double-precision scalars. + vectors_tape_allocator_t vectors_tape_allocator_; + + using vectors_lookup_allocator_t = aligned_allocator_gt; + using vectors_lookup_t = buffer_gt; + + /// @brief For every managed `compressed_slot_t` stores a pointer to the allocated vector copy. + mutable vectors_lookup_t vectors_lookup_; + + using available_threads_allocator_t = aligned_allocator_gt; + using available_threads_t = ring_gt; + + /// @brief Originally forms and array of integers [0, threads], marking all as available. + mutable available_threads_t available_threads_; + + /// @brief Mutex, controlling concurrent access to `available_threads_`. + mutable std::mutex available_threads_mutex_; + +#if defined(USEARCH_DEFINED_CPP17) + using shared_mutex_t = std::shared_mutex; +#else + using shared_mutex_t = unfair_shared_mutex_t; +#endif + using shared_lock_t = shared_lock_gt; + using unique_lock_t = std::unique_lock; + + struct key_and_slot_t { + vector_key_t key; + compressed_slot_t slot; + + bool any_slot() const { return slot == default_free_value(); } + static key_and_slot_t any_slot(vector_key_t key) { return {key, default_free_value()}; } + }; + + struct lookup_key_hash_t { + using is_transparent = void; + std::size_t operator()(key_and_slot_t const& k) const noexcept { return hash_gt{}(k.key); } + std::size_t operator()(vector_key_t const& k) const noexcept { return hash_gt{}(k); } + }; + + struct lookup_key_same_t { + using is_transparent = void; + bool operator()(key_and_slot_t const& a, vector_key_t const& b) const noexcept { return a.key == b; } + bool operator()(vector_key_t const& a, key_and_slot_t const& b) const noexcept { return a == b.key; } + bool operator()(key_and_slot_t const& a, key_and_slot_t const& b) const noexcept { return a.key == b.key; } + }; + + /// @brief Multi-Map from keys to IDs, and allocated vectors. + flat_hash_multi_set_gt slot_lookup_; + + /// @brief Mutex, controlling concurrent access to `slot_lookup_`. + mutable shared_mutex_t slot_lookup_mutex_; + + /// @brief Ring-shaped queue of deleted entries, to be reused on future insertions. + ring_gt free_keys_; + + /// @brief Mutex, controlling concurrent access to `free_keys_`. + mutable std::mutex free_keys_mutex_; + + /// @brief A constant for the reserved key value, used to mark deleted entries. + vector_key_t free_key_ = default_free_value(); + + /// @brief Locks the thread for the duration of the operation. + struct thread_lock_t { + index_dense_gt const& parent; + std::size_t thread_id = 0; + bool engaged = false; + + ~thread_lock_t() usearch_noexcept_m { + if (engaged) + parent.thread_unlock_(thread_id); + } + + thread_lock_t(thread_lock_t const&) = delete; + thread_lock_t& operator=(thread_lock_t const&) = delete; + + thread_lock_t(index_dense_gt const& parent, std::size_t thread_id, bool engaged = true) noexcept + : parent(parent), thread_id(thread_id), engaged(engaged) {} + thread_lock_t(thread_lock_t&& other) noexcept + : parent(other.parent), thread_id(other.thread_id), engaged(other.engaged) { + other.engaged = false; + } + }; + + public: + using cluster_result_t = typename index_t::cluster_result_t; + using add_result_t = typename index_t::add_result_t; + using stats_t = typename index_t::stats_t; + using match_t = typename index_t::match_t; + + /** + * @brief A search result, containing the found keys and distances. + * + * As the `index_dense_gt` manages the thread-pool on its own, the search result + * preserves the thread-lock to avoid undefined behaviors, when other threads + * start overwriting the results. + */ + struct search_result_t : public index_t::search_result_t { + inline search_result_t(index_dense_gt const& parent) noexcept + : index_t::search_result_t(), lock_(parent, 0, false) {} + search_result_t failed(error_t message) noexcept { + this->error = std::move(message); + return std::move(*this); + } + + private: + friend class index_dense_gt; + thread_lock_t lock_; + + inline search_result_t(typename index_t::search_result_t result, thread_lock_t lock) noexcept + : index_t::search_result_t(std::move(result)), lock_(std::move(lock)) {} + }; + + index_dense_gt() = default; + index_dense_gt(index_dense_gt&& other) + : config_(std::move(other.config_)), + + typed_(exchange(other.typed_, nullptr)), // + cast_buffer_(std::move(other.cast_buffer_)), // + casts_(std::move(other.casts_)), // + metric_(std::move(other.metric_)), // + + vectors_tape_allocator_(std::move(other.vectors_tape_allocator_)), // + vectors_lookup_(std::move(other.vectors_lookup_)), // + + available_threads_(std::move(other.available_threads_)), // + slot_lookup_(std::move(other.slot_lookup_)), // + free_keys_(std::move(other.free_keys_)), // + free_key_(std::move(other.free_key_)) {} // + + index_dense_gt& operator=(index_dense_gt&& other) { + swap(other); + return *this; + } + + /** + * @brief Swaps the contents of this index with another index. + * @param other The other index to swap with. + */ + void swap(index_dense_gt& other) { + std::swap(config_, other.config_); + + std::swap(typed_, other.typed_); + std::swap(cast_buffer_, other.cast_buffer_); + std::swap(casts_, other.casts_); + std::swap(metric_, other.metric_); + + std::swap(vectors_tape_allocator_, other.vectors_tape_allocator_); + std::swap(vectors_lookup_, other.vectors_lookup_); + + std::swap(available_threads_, other.available_threads_); + std::swap(slot_lookup_, other.slot_lookup_); + std::swap(free_keys_, other.free_keys_); + std::swap(free_key_, other.free_key_); + } + + ~index_dense_gt() { + if (typed_) + typed_->~index_t(); + index_allocator_t{}.deallocate(typed_, 1); + typed_ = nullptr; + } + + struct state_result_t { + index_dense_gt index; + error_t error; + + explicit operator bool() const noexcept { return !error; } + state_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + operator index_dense_gt&&() && { + if (error) + usearch_raise_runtime_error(error.what()); + return std::move(index); + } + }; + using copy_result_t = state_result_t; + + /** + * @brief Constructs an instance of ::index_dense_gt. + * @param[in] metric One of the provided or an @b ad-hoc metric, type-punned. + * @param[in] config The index configuration (optional). + * @param[in] free_key The key used for freed vectors (optional). + * @return An instance of ::index_dense_gt or error, wrapped in a `state_result_t`. + * + * ! If the `metric` isn't provided in this method, it has to be set with + * ! the `change_metric` method before the index can be used. Alternatively, + * ! if you are loading an existing index, the metric will be set automatically. + */ + static state_result_t make( // + metric_t metric = {}, // + index_dense_config_t config = {}, // + vector_key_t free_key = default_free_value()) { + + if (metric.missing()) + return state_result_t{}.failed("Metric won't be initialized!"); + error_t error = config.validate(); + if (error) + return state_result_t{}.failed(std::move(error)); + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return state_result_t{}.failed("Failed to allocate memory for the index!"); + + state_result_t result; + index_dense_gt& index = result.index; + index.config_ = config; + index.free_key_ = free_key; + + // In some cases the metric is not provided, and will be set later. + if (metric) { + scalar_kind_t scalar_kind = metric.scalar_kind(); + index.casts_ = casts_punned_t::make(scalar_kind); + index.metric_ = metric; + } + + new (raw) index_t(config); + index.typed_ = raw; + return result; + } + + /** + * @brief Constructs an instance of ::index_dense_gt from a serialized binary file. + * @param[in] path The path to the binary file. + * @param[in] view Whether to map the file into memory or load it. + * @return An instance of ::index_dense_gt or error, wrapped in a `state_result_t`. + */ + static state_result_t make(char const* path, bool view = false) { + state_result_t result; + serialization_result_t serialization_result = view ? result.index.view(path) : result.index.load(path); + if (!serialization_result) + return result.failed(std::move(serialization_result.error)); + return result; + } + + explicit operator bool() const { return typed_; } + std::size_t connectivity() const { return typed_->connectivity(); } + std::size_t size() const { return typed_->size() - free_keys_.size(); } + std::size_t capacity() const { return typed_->capacity(); } + std::size_t max_level() const { return typed_->max_level(); } + index_dense_config_t const& config() const { return config_; } + index_limits_t const& limits() const { return typed_->limits(); } + double inverse_log_connectivity() const { return typed_->inverse_log_connectivity(); } + std::size_t neighbors_base_bytes() const { return typed_->neighbors_base_bytes(); } + std::size_t neighbors_bytes() const { return typed_->neighbors_bytes(); } + bool multi() const { return config_.multi; } + std::size_t currently_available_threads() const { + std::unique_lock available_threads_lock(available_threads_mutex_); + return available_threads_.size(); + } + + // The metric and its properties + metric_t const& metric() const { return metric_; } + void change_metric(metric_t metric) { metric_ = std::move(metric); } + + scalar_kind_t scalar_kind() const { return metric_.scalar_kind(); } + metric_kind_t metric_kind() const { return metric_.metric_kind(); } + std::size_t bytes_per_vector() const { return metric_.bytes_per_vector(); } + std::size_t scalar_words() const { return metric_.scalar_words(); } + std::size_t dimensions() const { return metric_.dimensions(); } + + // Fetching and changing search criteria + std::size_t expansion_add() const { return config_.expansion_add; } + std::size_t expansion_search() const { return config_.expansion_search; } + void change_expansion_add(std::size_t n) { config_.expansion_add = n; } + void change_expansion_search(std::size_t n) { config_.expansion_search = n; } + + member_citerator_t cbegin() const { return typed_->cbegin(); } + member_citerator_t cend() const { return typed_->cend(); } + member_iterator_t begin() { return typed_->begin(); } + member_iterator_t end() { return typed_->end(); } + + stats_t stats() const { return typed_->stats(); } + stats_t stats(std::size_t level) const { return typed_->stats(level); } + stats_t stats(stats_t* stats_per_level, std::size_t max_level) const { + return typed_->stats(stats_per_level, max_level); + } + + dynamic_allocator_t const& allocator() const { return typed_->dynamic_allocator(); } + vector_key_t const& free_key() const { return free_key_; } + + /** + * @brief A relatively accurate lower bound on the amount of memory consumed by the system. + * In practice it's error will be below 10%. + * + * @see `serialized_length` for the length of the binary serialized representation. + */ + std::size_t memory_usage() const { + return // + typed_->memory_usage(0) + // + typed_->tape_allocator().total_wasted() + // + typed_->tape_allocator().total_reserved() + // + vectors_tape_allocator_.total_allocated(); + } + + static constexpr std::size_t any_thread() { return std::numeric_limits::max(); } + static constexpr distance_t infinite_distance() { return std::numeric_limits::max(); } + + struct aggregated_distances_t { + std::size_t count = 0; + distance_t mean = infinite_distance(); + distance_t min = infinite_distance(); + distance_t max = infinite_distance(); + }; + + // clang-format off + add_result_t add(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.b1x8); } + add_result_t add(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.i8); } + add_result_t add(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.f16); } + add_result_t add(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.bf16); } + add_result_t add(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.f32); } + add_result_t add(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread(), bool copy_vector = true) { return add_(key, vector, thread, copy_vector, casts_.from.f64); } + + search_result_t search(b1x8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.b1x8); } + search_result_t search(i8_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.i8); } + search_result_t search(f16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f16); } + search_result_t search(bf16_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.bf16); } + search_result_t search(f32_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f32); } + search_result_t search(f64_t const* vector, std::size_t wanted, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, dummy_predicate_t {}, thread, exact, casts_.from.f64); } + + template search_result_t filtered_search(b1x8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.b1x8); } + template search_result_t filtered_search(i8_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.i8); } + template search_result_t filtered_search(f16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f16); } + template search_result_t filtered_search(bf16_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.bf16); } + template search_result_t filtered_search(f32_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f32); } + template search_result_t filtered_search(f64_t const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread = any_thread(), bool exact = false) const { return search_(vector, wanted, std::forward(predicate), thread, exact, casts_.from.f64); } + + std::size_t get(vector_key_t key, b1x8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.b1x8); } + std::size_t get(vector_key_t key, i8_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.i8); } + std::size_t get(vector_key_t key, f16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f16); } + std::size_t get(vector_key_t key, bf16_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.bf16); } + std::size_t get(vector_key_t key, f32_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f32); } + std::size_t get(vector_key_t key, f64_t* vector, std::size_t vectors_count = 1) const { return get_(key, vector, vectors_count, casts_.to.f64); } + + cluster_result_t cluster(b1x8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.b1x8); } + cluster_result_t cluster(i8_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.i8); } + cluster_result_t cluster(f16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f16); } + cluster_result_t cluster(bf16_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.bf16); } + cluster_result_t cluster(f32_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f32); } + cluster_result_t cluster(f64_t const* vector, std::size_t level, std::size_t thread = any_thread()) const { return cluster_(vector, level, thread, casts_.from.f64); } + + aggregated_distances_t distance_between(vector_key_t key, b1x8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.b1x8); } + aggregated_distances_t distance_between(vector_key_t key, i8_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.i8); } + aggregated_distances_t distance_between(vector_key_t key, f16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f16); } + aggregated_distances_t distance_between(vector_key_t key, bf16_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.bf16); } + aggregated_distances_t distance_between(vector_key_t key, f32_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f32); } + aggregated_distances_t distance_between(vector_key_t key, f64_t const* vector, std::size_t thread = any_thread()) const { return distance_between_(key, vector, thread, casts_.to.f64); } + // clang-format on + + /** + * @brief Computes the distance between two managed entities. + * If either key maps into more than one vector, will aggregate results + * exporting the mean, maximum, and minimum values. + */ + aggregated_distances_t distance_between(vector_key_t a, vector_key_t b, std::size_t = any_thread()) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled!"); + shared_lock_t lock(slot_lookup_mutex_); + aggregated_distances_t result; + if (!multi()) { + auto a_it = slot_lookup_.find(key_and_slot_t::any_slot(a)); + auto b_it = slot_lookup_.find(key_and_slot_t::any_slot(b)); + bool a_missing = a_it == slot_lookup_.end(); + bool b_missing = b_it == slot_lookup_.end(); + if (a_missing || b_missing) + return result; + + key_and_slot_t a_key_and_slot = *a_it; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + key_and_slot_t b_key_and_slot = *b_it; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean = result.min = result.max = a_b_distance; + result.count = 1; + return result; + } + + auto a_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(a)); + auto b_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(b)); + bool a_missing = a_range.first == a_range.second; + bool b_missing = b_range.first == b_range.second; + if (a_missing || b_missing) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (a_range.first != a_range.second) { + key_and_slot_t a_key_and_slot = *a_range.first; + byte_t const* a_vector = vectors_lookup_[a_key_and_slot.slot]; + while (b_range.first != b_range.second) { + key_and_slot_t b_key_and_slot = *b_range.first; + byte_t const* b_vector = vectors_lookup_[b_key_and_slot.slot]; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++b_range.first; + } + ++a_range.first; + } + + result.mean /= result.count; + return result; + } + + /** + * @brief Identifies a node in a given `level`, that is the closest to the `key`. + */ + cluster_result_t cluster(vector_key_t key, std::size_t level, std::size_t thread = any_thread()) const { + + // Check if such `key` is even present. + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + cluster_result_t result; + if (key_range.first == key_range.second) + return result.failed("Key missing!"); + + index_cluster_config_t cluster_config; + thread_lock_t lock = thread_lock_(thread); + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + metric_proxy_t metric{*this}; + vector_key_t free_key_copy = free_key_; + auto allow = [free_key_copy](member_cref_t const& member) noexcept { return member.key != free_key_copy; }; + + // Find the closest cluster for any vector under that key. + while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* vector_data = vectors_lookup_[key_and_slot.slot]; + cluster_result_t new_result = typed_->cluster(vector_data, level, metric, cluster_config, allow); + if (!new_result) + return new_result; + if (new_result.cluster.distance < result.cluster.distance) + result = std::move(new_result); + + ++key_range.first; + } + return result; + } + + /** + * @brief Reserves memory for the index and the keyed lookup. + * @return `true` if the memory reservation was successful, `false` otherwise. + * + * ! No update or search operations should be running during this operation. + */ + bool try_reserve(index_limits_t limits) { + + // The slot lookup system will generally prefer power-of-two sizes. + if (config_.enable_key_lookups) { + unique_lock_t lock(slot_lookup_mutex_); + if (!slot_lookup_.try_reserve(limits.members)) + return false; + limits.members = slot_lookup_.capacity(); + } + + // Once the `slot_lookup_` grows, let's use its capacity as the new + // target for the `vectors_lookup_` to synchronize allocations and + // expensive index re-organizations. + if (limits.members != vectors_lookup_.size()) { + vectors_lookup_t new_vectors_lookup(limits.members); + if (!new_vectors_lookup) + return false; + if (vectors_lookup_.size() > 0) + std::memcpy(new_vectors_lookup.data(), vectors_lookup_.data(), + vectors_lookup_.size() * sizeof(byte_t*)); + vectors_lookup_ = std::move(new_vectors_lookup); + } + + // During reserve, no insertions may be happening, so we can safely overwrite the whole collection. + std::unique_lock available_threads_lock(available_threads_mutex_); + available_threads_.clear(); + if (!available_threads_.reserve(limits.threads())) + return false; + for (std::size_t i = 0; i < limits.threads(); i++) + available_threads_.push(i); + + // Allocate a buffer for the casted vectors. + cast_buffer_t cast_buffer(limits.threads() * metric_.bytes_per_vector()); + if (!cast_buffer) + return false; + cast_buffer_ = std::move(cast_buffer); + + return typed_->reserve(limits); + } + + void reserve(index_limits_t limits) { + if (!try_reserve(limits)) + usearch_raise_runtime_error("failed to reserve memory"); + } + + /** + * @brief Erases all the vectors from the index. + * + * Will change `size()` to zero, but will keep the same `capacity()`. + * Will keep the number of available threads/contexts the same as it was. + */ + void clear() { + unique_lock_t lookup_lock(slot_lookup_mutex_); + + std::unique_lock free_lock(free_keys_mutex_); + typed_->clear(); + slot_lookup_.clear(); + vectors_lookup_.reset(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + } + + /** + * @brief Erases all members from index, closing files, and returning RAM to OS. + * + * Will change both `size()` and `capacity()` to zero. + * Will deallocate all threads/contexts. + * If the index is memory-mapped - releases the mapping and the descriptor. + */ + void reset() { + + unique_lock_t lookup_lock(slot_lookup_mutex_); + std::unique_lock free_lock(free_keys_mutex_); + std::unique_lock available_threads_lock(available_threads_mutex_); + + if (typed_) + typed_->reset(); + slot_lookup_.clear(); + vectors_lookup_.reset(); + free_keys_.clear(); + vectors_tape_allocator_.reset(); + available_threads_.reset(); + } + + /** + * @brief Saves serialized binary index representation to a stream. + */ + template + serialization_result_t save_to_stream(output_callback_at&& output, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to put the vectors into the same file + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t dimensions[2]; + dimensions[0] = static_cast(typed_->size()); + dimensions[1] = static_cast(metric_.bytes_per_vector()); + if (!output(&dimensions, sizeof(dimensions))) + return result.failed("Failed to serialize into stream"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + + // Dump the vectors one after another + for (std::uint64_t i = 0; i != matrix_rows; ++i) { + byte_t* vector = vectors_lookup_[i]; + if (!output(vector, matrix_cols)) + return result.failed("Failed to serialize into stream"); + } + } + + // Augment metadata + { + index_dense_head_buffer_t buffer; + std::memset(buffer, 0, sizeof(buffer)); + index_dense_head_t head{buffer}; + std::memcpy(buffer, default_magic(), std::strlen(default_magic())); + + // Describe software version + using version_t = index_dense_head_t::version_t; + head.version_major = static_cast(USEARCH_VERSION_MAJOR); + head.version_minor = static_cast(USEARCH_VERSION_MINOR); + head.version_patch = static_cast(USEARCH_VERSION_PATCH); + + // Describes types used + head.kind_metric = metric_.metric_kind(); + head.kind_scalar = metric_.scalar_kind(); + head.kind_key = unum::usearch::scalar_kind(); + head.kind_compressed_slot = unum::usearch::scalar_kind(); + + head.count_present = size(); + head.count_deleted = typed_->size() - size(); + head.dimensions = dimensions(); + head.multi = multi(); + + if (!output(&buffer, sizeof(buffer))) + return result.failed("Failed to serialize into stream"); + } + + // Save the actual proximity graph + return typed_->save_to_stream(std::forward(output), std::forward(progress)); + } + + /** + * @brief Estimate the binary length (in bytes) of the serialized index. + */ + std::size_t serialized_length(serialization_config_t config = {}) const { + std::size_t dimensions_length = 0; + std::size_t matrix_length = 0; + if (!config.exclude_vectors) { + dimensions_length = config.use_64_bit_dimensions ? sizeof(std::uint64_t) * 2 : sizeof(std::uint32_t) * 2; + matrix_length = typed_->size() * metric_.bytes_per_vector(); + } + return dimensions_length + matrix_length + sizeof(index_dense_head_buffer_t) + typed_->serialized_length(); + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] input The input stream to read from. + * @param[in] config Configuration parameters for imports. + * @param[in] progress Callback to report the execution progress. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t load_from_stream(input_callback_at&& input, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + index_limits_t old_limits = typed_ ? typed_->limits() : index_limits_t{}; + reset(); + + // Infer the new index size + serialization_result_t result; + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + + // We may not want to load the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 32-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } else { + std::uint64_t dimensions[2]; + if (!input(&dimensions, sizeof(dimensions))) + return result.failed("Failed to read 64-bit dimensions of the matrix"); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + } + // Load the vectors one after another + vectors_lookup_ = vectors_lookup_t(matrix_rows); + if (!vectors_lookup_) + return result.failed("Failed to allocate memory to address vectors"); + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) { + byte_t* vector = vectors_tape_allocator_.allocate(matrix_cols); + if (!input(vector, matrix_cols)) + return result.failed("Failed to read vectors"); + vectors_lookup_[slot] = vector; + } + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (!input(buffer, sizeof(buffer))) + return result.failed("Failed to read the index "); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // fix pre-2.10 headers + fix_pre_2_10_metadata(head); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + // available_threads_.size() will be updated to old_limits.threads() later in this + // method, so use that as the number of threads to prepare for. + cast_buffer_ = cast_buffer_t(old_limits.threads() * metric_.bytes_per_vector()); + if (!cast_buffer_) + return result.failed("Failed to allocate memory for the casts"); + casts_ = casts_punned_t::make(head.kind_scalar); + } + + // Pull the actual proximity graph + if (!typed_) { + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Failed to allocate memory for the index"); + new (raw) index_t(config_); + typed_ = raw; + } + result = typed_->load_from_stream(std::forward(input), std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + old_limits.members = static_cast(matrix_rows); + if (!typed_->try_reserve(old_limits)) + return result.failed("Failed to reserve memory for the index"); + + // After the index is loaded, we may have to resize the `available_threads_` to + // match the limits of the underlying engine. + available_threads_t available_threads; + std::size_t max_threads = old_limits.threads(); + if (!available_threads.reserve(max_threads)) + return result.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + available_threads_ = std::move(available_threads); + + reindex_keys_(); + return result; + } + + /** + * @brief Parses the index from file, without loading it into RAM. + * @param[in] file The input file to read from. + * @param[in] offset The offset in the file to start reading from. + * @param[in] config Configuration parameters for imports. + * @param[in] progress Callback to report the execution progress. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t view(memory_mapped_file_t file, // + std::size_t offset = 0, serialization_config_t config = {}, // + progress_at&& progress = {}) { + + // Discard all previous memory allocations of `vectors_tape_allocator_` + index_limits_t old_limits = typed_ ? typed_->limits() : index_limits_t{}; + reset(); + + serialization_result_t result = file.open_if_not(); + if (!result) + return result; + + // Infer the new index size + std::uint64_t matrix_rows = 0; + std::uint64_t matrix_cols = 0; + span_punned_t vectors_buffer; + + // We may not want to fetch the vectors from the same file, or allow attaching them afterwards + if (!config.exclude_vectors) { + // Save the matrix size + if (!config.use_64_bit_dimensions) { + std::uint32_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } else { + std::uint64_t dimensions[2]; + if (file.size() - offset < sizeof(dimensions)) + return result.failed("File is corrupted and lacks matrix dimensions"); + std::memcpy(&dimensions, file.data() + offset, sizeof(dimensions)); + matrix_rows = dimensions[0]; + matrix_cols = dimensions[1]; + offset += sizeof(dimensions); + } + vectors_buffer = {file.data() + offset, static_cast(matrix_rows * matrix_cols)}; + offset += vectors_buffer.size(); + } + + // Load metadata and choose the right metric + { + index_dense_head_buffer_t buffer; + if (file.size() - offset < sizeof(buffer)) + return result.failed("File is corrupted and lacks a header"); + + std::memcpy(buffer, file.data() + offset, sizeof(buffer)); + + index_dense_head_t head{buffer}; + if (std::memcmp(buffer, default_magic(), std::strlen(default_magic())) != 0) + return result.failed("Magic header mismatch - the file isn't an index"); + + // fix pre-2.10 headers + fix_pre_2_10_metadata(head); + + // Validate the software version + if (head.version_major != USEARCH_VERSION_MAJOR) + return result.failed("File format may be different, please rebuild"); + + // Check the types used + if (head.kind_key != unum::usearch::scalar_kind()) + return result.failed("Key type doesn't match, consider rebuilding"); + if (head.kind_compressed_slot != unum::usearch::scalar_kind()) + return result.failed("Slot type doesn't match, consider rebuilding"); + + config_.multi = head.multi; + metric_ = metric_t::builtin(head.dimensions, head.kind_metric, head.kind_scalar); + // available_threads_.size() will be updated to old_limits.threads() later in this + // method, so use that as the number of threads to prepare for. + cast_buffer_ = cast_buffer_t(old_limits.threads() * metric_.bytes_per_vector()); + if (!cast_buffer_) + return result.failed("Failed to allocate memory for the casts"); + casts_ = casts_punned_t::make(head.kind_scalar); + offset += sizeof(buffer); + } + + // Pull the actual proximity graph + if (!typed_) { + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return result.failed("Failed to allocate memory for the index"); + new (raw) index_t(config_); + typed_ = raw; + } + result = typed_->view(std::move(file), offset, std::forward(progress)); + if (!result) + return result; + if (typed_->size() != static_cast(matrix_rows)) + return result.failed("Index size and the number of vectors doesn't match"); + old_limits.members = static_cast(matrix_rows); + if (!typed_->try_reserve(old_limits)) + return result.failed("Failed to reserve memory for the index"); + + // Address the vectors + vectors_lookup_ = vectors_lookup_t(matrix_rows); + if (!vectors_lookup_) + return result.failed("Failed to allocate memory to address vectors"); + if (!config.exclude_vectors) + for (std::uint64_t slot = 0; slot != matrix_rows; ++slot) + vectors_lookup_[slot] = (byte_t*)vectors_buffer.data() + matrix_cols * slot; + + // After the index is loaded, we may have to resize the `available_threads_` to + // match the limits of the underlying engine. + available_threads_t available_threads; + std::size_t max_threads = old_limits.threads(); + if (!available_threads.reserve(max_threads)) + return result.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + available_threads_ = std::move(available_threads); + + reindex_keys_(); + return result; + } + + /** + * @brief Saves the index to a file. + * @param[in] file The output file to write to. + * @param[in] config Configuration parameters for exports. + * @param[in] progress Callback to report the execution progress. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t save(output_file_t file, serialization_config_t config = {}, + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + io_result = file.write(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t save(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = save_to_stream( + [&](void const* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(file.data() + offset, buffer, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + /** + * @brief Parses the index from file to RAM. + * @param[in] file The input file to read from. + * @param[in] config Configuration parameters for imports. + * @param[in] progress Progress callback. + * @return Outcome descriptor explicitly convertible to boolean. + */ + template + serialization_result_t load(input_file_t file, serialization_config_t config = {}, progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + io_result = file.read(buffer, length); + return !!io_result; + }, + config, std::forward(progress)); + + if (!stream_result) { + io_result.error.release(); + return stream_result; + } + return io_result; + } + + /** + * @brief Memory-maps the serialized binary index representation from disk, + * @b without copying data into RAM, and fetching it on-demand. + */ + template + serialization_result_t load(memory_mapped_file_t file, // + std::size_t offset = 0, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + + serialization_result_t io_result = file.open_if_not(); + if (!io_result) + return io_result; + + serialization_result_t stream_result = load_from_stream( + [&](void* buffer, std::size_t length) { + if (offset + length > file.size()) + return false; + std::memcpy(buffer, file.data() + offset, length); + offset += length; + return true; + }, + config, std::forward(progress)); + + return stream_result; + } + + template + serialization_result_t save(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) const { + return save(output_file_t(file_path), config, std::forward(progress)); + } + + template + serialization_result_t load(char const* file_path, // + serialization_config_t config = {}, // + progress_at&& progress = {}) { + return load(input_file_t(file_path), config, std::forward(progress)); + } + + /** + * @brief Checks if a vector with specified key is present. + * @return `true` if the key is present in the index, `false` otherwise. + */ + bool contains(vector_key_t key) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.contains(key_and_slot_t::any_slot(key)); + } + + /** + * @brief Count the number of vectors with specified key present. + * @return Zero if nothing is found, a positive integer otherwise. + */ + std::size_t count(vector_key_t key) const { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); + shared_lock_t lock(slot_lookup_mutex_); + return slot_lookup_.count(key_and_slot_t::any_slot(key)); + } + + struct labeling_result_t { + error_t error{}; + std::size_t completed{}; + + explicit operator bool() const noexcept { return !error; } + labeling_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Removes an entry with the specified key from the index. + * @param[in] key The key of the entry to remove. + * @return The ::labeling_result_t indicating the result of the removal operation. + * If the removal was successful, `result.completed` will be `true`. + * If the key was not found in the index, `result.completed` will be `false`. + * If an error occurred during the removal operation, `result.error` will contain an error message. + */ + labeling_result_t remove(vector_key_t key) { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); + labeling_result_t result; + if (typed_->is_immutable()) + return result.failed("Can't remove from an immutable index"); + + unique_lock_t lookup_lock(slot_lookup_mutex_); + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + if (matching_slots.first == matching_slots.second) + return result; + + // Grow the removed entries ring, if needed + std::size_t matching_count = std::distance(matching_slots.first, matching_slots.second); + std::unique_lock free_lock(free_keys_mutex_); + std::size_t free_count_old = free_keys_.size(); + if (!free_keys_.reserve(free_count_old + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + } + slot_lookup_.erase(key); + result.completed = matching_count; + usearch_assert_m(free_keys_.size() == free_count_old + matching_count, "Free keys count mismatch"); + + return result; + } + + /** + * @brief Removes multiple entries with the specified keys from the index. + * @param[in] keys_begin The beginning of the keys range. + * @param[in] keys_end The ending of the keys range. + * @return The ::labeling_result_t indicating the result of the removal operation. + * `result.completed` will contain the number of keys that were successfully removed. + * `result.error` will contain an error message if an error occurred during the removal operation. + */ + template + labeling_result_t remove(keys_iterator_at keys_begin, keys_iterator_at keys_end) { + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled"); + + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + std::unique_lock free_lock(free_keys_mutex_); + // Grow the removed entries ring, if needed + std::size_t matching_count = 0; + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) + matching_count += slot_lookup_.count(key_and_slot_t::any_slot(*keys_it)); + + if (!free_keys_.reserve(free_keys_.size() + matching_count)) + return result.failed("Can't allocate memory for a free-list"); + + // Remove them one-by-one + for (auto keys_it = keys_begin; keys_it != keys_end; ++keys_it) { + vector_key_t key = *keys_it; + auto matching_slots = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + // A removed entry would be: + // - present in `free_keys_` + // - missing in the `slot_lookup_` + // - marked in the `typed_` index with a `free_key_` + matching_count = 0; + for (auto slots_it = matching_slots.first; slots_it != matching_slots.second; ++slots_it) { + compressed_slot_t slot = (*slots_it).slot; + free_keys_.push(slot); + typed_->at(slot).key = free_key_; + ++matching_count; + } + + slot_lookup_.erase(key); + result.completed += matching_count; + } + + return result; + } + + /** + * @brief Renames an entry with the specified key to a new key. + * @param[in] from The current key of the entry to rename. + * @param[in] to The new key to assign to the entry. + * @return The ::labeling_result_t indicating the result of the rename operation. + * If the rename was successful, `result.completed` will be `true`. + * If the entry with the current key was not found, `result.completed` will be `false`. + */ + labeling_result_t rename(vector_key_t from, vector_key_t to) { + labeling_result_t result; + unique_lock_t lookup_lock(slot_lookup_mutex_); + + if (!multi() && slot_lookup_.contains(key_and_slot_t::any_slot(to))) + return result.failed("Renaming impossible, the key is already in use"); + + // The `from` may map to multiple entries + while (true) { + key_and_slot_t key_and_slot_removed; + if (!slot_lookup_.pop_first(key_and_slot_t::any_slot(from), key_and_slot_removed)) + break; + + key_and_slot_t key_and_slot_replacing{to, key_and_slot_removed.slot}; + slot_lookup_.try_emplace(key_and_slot_replacing); // This can't fail + typed_->at(key_and_slot_removed.slot).key = to; + ++result.completed; + } + + return result; + } + + /** + * @brief Exports a range of keys for the vectors present in the index. + * @param[out] keys Pointer to the array where the keys will be exported. + * @param[in] offset The number of keys to skip. Useful for pagination. + * @param[in] limit The maximum number of keys to export, that can fit in ::keys. + */ + void export_keys(vector_key_t* keys, std::size_t offset, std::size_t limit) const { + shared_lock_t lock(slot_lookup_mutex_); + offset = (std::min)(offset, slot_lookup_.size()); + slot_lookup_.for_each([&](key_and_slot_t const& key_and_slot) { + if (offset) + // Skip the first `offset` entries + --offset; + else if (limit) { + *keys = key_and_slot.key; + ++keys; + --limit; + } + }); + } + + /** + * @brief Copies the ::index_dense_gt @b with all the data in it. + * @param config The copy configuration (optional). + * @return A copy of the ::index_dense_gt instance. + */ + copy_result_t copy(index_dense_copy_config_t config = {}) const { + copy_result_t result = fork(); + if (!result) + return result; + + auto typed_result = typed_->copy(config); + if (!typed_result) + return result.failed(std::move(typed_result.error)); + + // Export the free (removed) slot numbers + index_dense_gt& copy = result.index; + if (!copy.free_keys_.reserve(free_keys_.size())) + return result.failed(std::move(typed_result.error)); + for (std::size_t i = 0; i != free_keys_.size(); ++i) + copy.free_keys_.push(free_keys_[i]); + + // Allocate buffers and move the vectors themselves + copy.vectors_lookup_ = vectors_lookup_t(vectors_lookup_.size()); + if (!copy.vectors_lookup_) + return result.failed("Out of memory!"); + if (!config.force_vector_copy && copy.config_.exclude_vectors) { + std::memcpy(copy.vectors_lookup_.data(), vectors_lookup_.data(), vectors_lookup_.size() * sizeof(byte_t*)); + } else { + std::size_t slots_count = typed_result.index.size(); + for (std::size_t slot = 0; slot != slots_count; ++slot) + copy.vectors_lookup_[slot] = copy.vectors_tape_allocator_.allocate(copy.metric_.bytes_per_vector()); + if (std::count(copy.vectors_lookup_.begin(), copy.vectors_lookup_.begin() + slots_count, nullptr)) + return result.failed("Out of memory!"); + for (std::size_t slot = 0; slot != slots_count; ++slot) + std::memcpy(copy.vectors_lookup_[slot], vectors_lookup_[slot], metric_.bytes_per_vector()); + } + + copy.slot_lookup_ = slot_lookup_; // TODO: Handle out of memory + *copy.typed_ = std::move(typed_result.index); + return result; + } + + /** + * @brief Copies the ::index_dense_gt model @b without any data. + * @return A similarly configured ::index_dense_gt instance. + */ + copy_result_t fork() const { + + cast_buffer_t cast_buffer(cast_buffer_.size()); + if (!cast_buffer) + return state_result_t{}.failed("Failed to allocate memory for the casts!"); + available_threads_t available_threads; + std::size_t max_threads = limits().threads(); + if (!available_threads.reserve(max_threads)) + return state_result_t{}.failed("Failed to allocate memory for the available threads!"); + for (std::size_t i = 0; i < max_threads; i++) + available_threads.push(i); + index_t* raw = index_allocator_t{}.allocate(1); + if (!raw) + return state_result_t{}.failed("Failed to allocate memory for the index!"); + + copy_result_t result; + index_dense_gt& other = result.index; + index_limits_t other_limits = limits(); + other_limits.members = 0; + other.config_ = config_; + other.cast_buffer_ = std::move(cast_buffer); + other.casts_ = casts_; + + other.metric_ = metric_; + other.available_threads_ = std::move(available_threads); + other.free_key_ = free_key_; + + new (raw) index_t(config()); + raw->try_reserve(other_limits); + other.typed_ = raw; + return result; + } + + struct compaction_result_t { + error_t error{}; + std::size_t pruned_edges{}; + + explicit operator bool() const noexcept { return !error; } + compaction_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. + */ + template + compaction_result_t isolate(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + std::atomic pruned_edges; + auto allow = [&](member_cref_t const& member) noexcept { + bool freed = member.key == free_key_; + pruned_edges += freed; + return !freed; + }; + typed_->isolate(allow, std::forward(executor), std::forward(progress)); + result.pruned_edges = pruned_edges; + return result; + } + + class values_proxy_t { + index_dense_gt const* index_; + + public: + values_proxy_t(index_dense_gt const& index) noexcept : index_(&index) {} + byte_t const* operator[](compressed_slot_t slot) const noexcept { return index_->vectors_lookup_[slot]; } + byte_t const* operator[](member_citerator_t it) const noexcept { return index_->vectors_lookup_[get_slot(it)]; } + }; + + /** + * @brief Performs compaction on the index, pruning links to removed entries. + * @param executor The executor parallel processing. Default ::dummy_executor_t single-threaded. + * @param progress The progress tracker instance to use. Default ::dummy_progress_t reports nothing. + * @return The ::compaction_result_t indicating the result of the compaction operation. + * `result.pruned_edges` will contain the number of edges that were removed. + * `result.error` will contain an error message if an error occurred during the compaction operation. + */ + template + compaction_result_t compact(executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + compaction_result_t result; + + vectors_lookup_t new_vectors_lookup(vectors_lookup_.size()); + if (!new_vectors_lookup) + return result.failed("Out of memory!"); + + vectors_tape_allocator_t new_vectors_allocator; + + auto track_slot_change = [&](vector_key_t, compressed_slot_t old_slot, compressed_slot_t new_slot) { + byte_t* new_vector = new_vectors_allocator.allocate(metric_.bytes_per_vector()); + byte_t* old_vector = vectors_lookup_[old_slot]; + std::memcpy(new_vector, old_vector, metric_.bytes_per_vector()); + new_vectors_lookup[new_slot] = new_vector; + }; + typed_->compact(values_proxy_t{*this}, metric_proxy_t{*this}, track_slot_change, + std::forward(executor), std::forward(progress)); + vectors_lookup_ = std::move(new_vectors_lookup); + vectors_tape_allocator_ = std::move(new_vectors_allocator); + return result; + } + + template < // + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + join_result_t join( // + index_dense_gt const& women, // + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) const { + + index_dense_gt const& men = *this; + return unum::usearch::join( // + *men.typed_, *women.typed_, // + values_proxy_t{men}, values_proxy_t{women}, // + metric_proxy_t{men}, metric_proxy_t{women}, // + config, // + std::forward(man_to_woman), // + std::forward(woman_to_man), // + std::forward(executor), // + std::forward(progress)); + } + + struct clustering_result_t { + error_t error{}; + std::size_t clusters{}; + std::size_t visited_members{}; + std::size_t computed_distances{}; + + explicit operator bool() const noexcept { return !error; } + clustering_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } + }; + + /** + * @brief Implements clustering, classifying the given objects (vectors of member keys) + * into a given number of clusters. + * + * @param[in] queries_begin Iterator pointing to the first query. + * @param[in] queries_end Iterator pointing to the last query. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + * @param[in] config Configuration parameters for clustering. + * + * @param[out] cluster_keys Pointer to the array where the cluster keys will be exported. + * @param[out] cluster_distances Pointer to the array where the distances to those centroids will be exported. + */ + template < // + typename queries_iterator_at, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > + clustering_result_t cluster( // + queries_iterator_at queries_begin, // + queries_iterator_at queries_end, // + index_dense_clustering_config_t config, // + vector_key_t* cluster_keys, // + distance_t* cluster_distances, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) { + + std::size_t const queries_count = queries_end - queries_begin; + + // Find the first level (top -> down) that has enough nodes to exceed `config.min_clusters`. + std::size_t level = max_level(); + if (config.min_clusters) { + for (; level > 1; --level) { + if (stats(level).nodes > config.min_clusters) + break; + } + } else + level = 1, config.max_clusters = stats(1).nodes, config.min_clusters = 2; + + clustering_result_t result; + if (max_level() < 2) + return result.failed("Index too small to cluster!"); + + // A structure used to track the popularity of a specific cluster + struct cluster_t { + vector_key_t centroid; + vector_key_t merged_into; + std::size_t popularity; + byte_t* vector; + }; + + auto centroid_id = [](cluster_t const& a, cluster_t const& b) { return a.centroid < b.centroid; }; + auto higher_popularity = [](cluster_t const& a, cluster_t const& b) { return a.popularity > b.popularity; }; + + std::atomic visited_members(0); + std::atomic computed_distances(0); + std::atomic atomic_error{nullptr}; + + using dynamic_allocator_traits_t = std::allocator_traits; + using clusters_allocator_t = typename dynamic_allocator_traits_t::template rebind_alloc; + buffer_gt clusters(queries_count); + if (!clusters) + return result.failed("Out of memory!"); + + map_to_clusters: + // Concurrently perform search until a certain depth + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + auto result = cluster(queries_begin[query_idx], level, thread_idx); + if (!result) { + atomic_error = result.error.release(); + return false; + } + + cluster_keys[query_idx] = result.cluster.member.key; + cluster_distances[query_idx] = result.cluster.distance; + + // Export in case we need to refine afterwards + clusters[query_idx].centroid = result.cluster.member.key; + clusters[query_idx].vector = vectors_lookup_[result.cluster.member.slot]; + clusters[query_idx].merged_into = free_key(); + clusters[query_idx].popularity = 1; + + visited_members += result.visited_members; + computed_distances += result.computed_distances; + return true; + }); + + if (atomic_error) + return result.failed(atomic_error.load()); + + // Now once we have identified the closest clusters, + // we can try reducing their quantity, refining + std::sort(clusters.begin(), clusters.end(), centroid_id); + + // Transform into run-length encoding, computing the number of unique clusters + std::size_t unique_clusters = 0; + { + std::size_t last_idx = 0; + for (std::size_t current_idx = 1; current_idx != clusters.size(); ++current_idx) { + if (clusters[last_idx].centroid == clusters[current_idx].centroid) { + clusters[last_idx].popularity++; + } else { + last_idx++; + clusters[last_idx] = clusters[current_idx]; + } + } + unique_clusters = last_idx + 1; + } + + // In some cases the queries may be co-located, all mapping into the same cluster on that + // level. In that case we refine the granularity and dive deeper into clusters: + if (unique_clusters < config.min_clusters && level > 1) { + level--; + goto map_to_clusters; + } + + std::sort(clusters.data(), clusters.data() + unique_clusters, higher_popularity); + + // If clusters are too numerous, merge the ones that are too close to each other. + std::size_t merge_cycles = 0; + merge_nearby_clusters: + if (unique_clusters > config.max_clusters) { + + cluster_t& merge_source = clusters[unique_clusters - 1]; + std::size_t merge_target_idx = 0; + distance_t merge_distance = std::numeric_limits::max(); + + for (std::size_t candidate_idx = 0; candidate_idx + 1 < unique_clusters; ++candidate_idx) { + distance_t distance = metric_(merge_source.vector, clusters[candidate_idx].vector); + if (distance < merge_distance) { + merge_distance = distance; + merge_target_idx = candidate_idx; + } + } + + merge_source.merged_into = clusters[merge_target_idx].centroid; + clusters[merge_target_idx].popularity += exchange(merge_source.popularity, 0); + + // The target object may have to be swapped a few times to get to optimal position. + while (merge_target_idx && + clusters[merge_target_idx - 1].popularity < clusters[merge_target_idx].popularity) + std::swap(clusters[merge_target_idx - 1], clusters[merge_target_idx]), --merge_target_idx; + + unique_clusters--; + merge_cycles++; + goto merge_nearby_clusters; + } + + // Replace evicted clusters + if (merge_cycles) { + // Sort dropped clusters by name to accelerate future lookups + auto clusters_end = clusters.data() + config.max_clusters + merge_cycles; + std::sort(clusters.data(), clusters_end, centroid_id); + + executor.dynamic(queries_count, [&](std::size_t thread_idx, std::size_t query_idx) { + vector_key_t& cluster_key = cluster_keys[query_idx]; + distance_t& cluster_distance = cluster_distances[query_idx]; + + // Recursively trace replacements of that cluster + while (true) { + // To avoid implementing heterogeneous comparisons, lets wrap the `cluster_key` + cluster_t updated_cluster; + updated_cluster.centroid = cluster_key; + updated_cluster = *std::lower_bound(clusters.data(), clusters_end, updated_cluster, centroid_id); + if (updated_cluster.merged_into == free_key()) + break; + cluster_key = updated_cluster.merged_into; + } + + cluster_distance = distance_between(cluster_key, queries_begin[query_idx], thread_idx).mean; + return true; + }); + } + + result.computed_distances = computed_distances; + result.visited_members = visited_members; + result.clusters = unique_clusters; + + (void)progress; + return result; + } + + private: + thread_lock_t thread_lock_(std::size_t thread_id) const usearch_noexcept_m { + if (thread_id != any_thread()) + return {*this, thread_id, false}; + + available_threads_mutex_.lock(); + usearch_assert_m(available_threads_.size(), "No available threads to lock"); + available_threads_.try_pop(thread_id); + available_threads_mutex_.unlock(); + return {*this, thread_id, true}; + } + + void thread_unlock_(std::size_t thread_id) const usearch_noexcept_m { + available_threads_mutex_.lock(); + usearch_assert_m(available_threads_.size() < available_threads_.capacity(), "Too many threads unlocked"); + available_threads_.push(thread_id); + available_threads_mutex_.unlock(); + } + + template + add_result_t add_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, bool copy_vector, cast_punned_t const& cast) { + + if (!multi() && config().enable_key_lookups && contains(key)) + return add_result_t{}.failed("Duplicate keys not allowed in high-level wrappers"); + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data, copy_vector = true; + } + + // Check if there are some removed entries, whose nodes we can reuse + compressed_slot_t free_slot = default_free_value(); + { + std::unique_lock lock(free_keys_mutex_); + free_keys_.try_pop(free_slot); + } + + // Perform the insertion or the update + bool reuse_node = free_slot != default_free_value(); + auto on_success = [&](member_ref_t member) { + if (config_.enable_key_lookups) { + unique_lock_t slot_lock(slot_lookup_mutex_); + slot_lookup_.try_emplace(key_and_slot_t{key, static_cast(member.slot)}); + } + if (copy_vector) { + if (!reuse_node) + vectors_lookup_[member.slot] = vectors_tape_allocator_.allocate(metric_.bytes_per_vector()); + std::memcpy(vectors_lookup_[member.slot], vector_data, metric_.bytes_per_vector()); + } else + vectors_lookup_[member.slot] = (byte_t*)vector_data; + }; + + index_update_config_t update_config; + update_config.thread = lock.thread_id; + update_config.expansion = config_.expansion_add; + + metric_proxy_t metric{*this}; + return reuse_node // + ? typed_->update(typed_->iterator_at(free_slot), key, vector_data, metric, update_config, on_success) + : typed_->add(key, vector_data, metric, update_config, on_success); + } + + template + search_result_t search_(scalar_at const* vector, std::size_t wanted, predicate_at&& predicate, std::size_t thread, + bool exact, cast_punned_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_search_config_t search_config; + search_config.thread = lock.thread_id; + search_config.expansion = config_.expansion_search; + search_config.exact = exact; + + vector_key_t free_key_copy = free_key_; + if (std::is_same::type, dummy_predicate_t>::value) { + auto allow = [free_key_copy](member_cref_t const& member) noexcept { + return (vector_key_t)member.key != free_key_copy; + }; + auto typed_result = typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + return search_result_t{std::move(typed_result), std::move(lock)}; + } else { + auto allow = [free_key_copy, &predicate](member_cref_t const& member) noexcept { + return (vector_key_t)member.key != free_key_copy && predicate(member.key); + }; + auto typed_result = typed_->search(vector_data, wanted, metric_proxy_t{*this}, search_config, allow); + return search_result_t{std::move(typed_result), std::move(lock)}; + } + } + + template + cluster_result_t cluster_( // + scalar_at const* vector, std::size_t level, // + std::size_t thread, cast_punned_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + index_cluster_config_t cluster_config; + cluster_config.thread = lock.thread_id; + cluster_config.expansion = config_.expansion_search; + + vector_key_t free_key_copy = free_key_; + auto allow = [free_key_copy](member_cref_t const& member) noexcept { return member.key != free_key_copy; }; + return typed_->cluster(vector_data, level, metric_proxy_t{*this}, cluster_config, allow); + } + + template + aggregated_distances_t distance_between_( // + vector_key_t key, scalar_at const* vector, // + std::size_t thread, cast_punned_t const& cast) const { + + // Cast the vector, if needed for compatibility with `metric_` + thread_lock_t lock = thread_lock_(thread); + byte_t const* vector_data = reinterpret_cast(vector); + { + byte_t* casted_data = cast_buffer_.data() + metric_.bytes_per_vector() * lock.thread_id; + bool casted = cast(vector_data, dimensions(), casted_data); + if (casted) + vector_data = casted_data; + } + + // Check if such `key` is even present. + usearch_assert_m(config().enable_key_lookups, "Key lookups are disabled!"); + shared_lock_t slots_lock(slot_lookup_mutex_); + auto key_range = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + aggregated_distances_t result; + if (key_range.first == key_range.second) + return result; + + result.min = std::numeric_limits::max(); + result.max = std::numeric_limits::min(); + result.mean = 0; + result.count = 0; + + while (key_range.first != key_range.second) { + key_and_slot_t key_and_slot = *key_range.first; + byte_t const* a_vector = vectors_lookup_[key_and_slot.slot]; + byte_t const* b_vector = vector_data; + distance_t a_b_distance = metric_(a_vector, b_vector); + + result.mean += a_b_distance; + result.min = (std::min)(result.min, a_b_distance); + result.max = (std::max)(result.max, a_b_distance); + result.count++; + + // + ++key_range.first; + } + + result.mean /= result.count; + return result; + } + + void reindex_keys_() { + + // Estimate number of entries first + std::size_t count_total = typed_->size(); + std::size_t count_removed = 0; + for (std::size_t i = 0; i != count_total; ++i) { + auto member_slot = static_cast(i); + member_cref_t member = typed_->at(member_slot); + count_removed += member.key == free_key_; + } + + if (!count_removed && !config_.enable_key_lookups) + return; + + // Pull entries from the underlying `typed_` into either + // into `slot_lookup_`, or `free_keys_` if they are unused. + unique_lock_t lock(slot_lookup_mutex_); + slot_lookup_.clear(); + if (config_.enable_key_lookups) + slot_lookup_.reserve(count_total - count_removed); + free_keys_.clear(); + free_keys_.reserve(count_removed); + for (std::size_t i = 0; i != typed_->size(); ++i) { + auto member_slot = static_cast(i); + member_cref_t member = typed_->at(member_slot); + if (member.key == free_key_) + free_keys_.push(member_slot); + else if (config_.enable_key_lookups) + slot_lookup_.try_emplace(key_and_slot_t{vector_key_t(member.key), member_slot}); + } + } + + template + std::size_t get_(vector_key_t key, scalar_at* reconstructed, std::size_t vectors_limit, + cast_punned_t const& cast) const { + + if (!multi()) { + compressed_slot_t slot; + // Find the matching ID + { + shared_lock_t lock(slot_lookup_mutex_); + auto it = slot_lookup_.find(key_and_slot_t::any_slot(key)); + if (it == slot_lookup_.end()) + return false; + slot = (*it).slot; + } + // Export the entry + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + bool casted = cast(punned_vector, dimensions(), (byte_t*)reconstructed); + if (!casted) + std::memcpy(reconstructed, punned_vector, metric_.bytes_per_vector()); + return true; + } else { + shared_lock_t lock(slot_lookup_mutex_); + auto equal_range_pair = slot_lookup_.equal_range(key_and_slot_t::any_slot(key)); + std::size_t count_exported = 0; + for (auto begin = equal_range_pair.first; + begin != equal_range_pair.second && count_exported != vectors_limit; ++begin, ++count_exported) { + // + compressed_slot_t slot = (*begin).slot; + byte_t const* punned_vector = reinterpret_cast(vectors_lookup_[slot]); + byte_t* reconstructed_vector = (byte_t*)reconstructed + metric_.bytes_per_vector() * count_exported; + bool casted = cast(punned_vector, dimensions(), reconstructed_vector); + if (!casted) + std::memcpy(reconstructed_vector, punned_vector, metric_.bytes_per_vector()); + } + return count_exported; + } + } +}; + +using index_dense_t = index_dense_gt<>; +using index_dense_big_t = index_dense_gt; + +/** + * @brief Adapts the Male-Optimal Stable Marriage algorithm for unequal sets + * to perform fast one-to-one matching between two large collections + * of vectors, using approximate nearest neighbors search. + * + * @param[inout] man_to_woman Container to map ::first keys to ::second. + * @param[inout] woman_to_man Container to map ::second keys to ::first. + * @param[in] executor Thread-pool to execute the job in parallel. + * @param[in] progress Callback to report the execution progress. + */ +template < // + + typename men_key_at, // + typename women_key_at, // + typename men_slot_at, // + typename women_slot_at, // + + typename man_to_woman_at = dummy_key_to_key_mapping_t, // + typename woman_to_man_at = dummy_key_to_key_mapping_t, // + typename executor_at = dummy_executor_t, // + typename progress_at = dummy_progress_t // + > +static join_result_t join( // + index_dense_gt const& men, // + index_dense_gt const& women, // + + index_join_config_t config = {}, // + man_to_woman_at&& man_to_woman = man_to_woman_at{}, // + woman_to_man_at&& woman_to_man = woman_to_man_at{}, // + executor_at&& executor = executor_at{}, // + progress_at&& progress = progress_at{}) { + + return men.join( // + women, config, // + std::forward(woman_to_man), // + std::forward(man_to_woman), // + std::forward(executor), // + std::forward(progress)); +} + +} // namespace usearch +} // namespace unum diff --git a/zig/usearch/include/index_plugins.hpp b/zig/usearch/include/index_plugins.hpp new file mode 100644 index 000000000..3efffcc0d --- /dev/null +++ b/zig/usearch/include/index_plugins.hpp @@ -0,0 +1,3033 @@ +#pragma once +#define __STDC_WANT_IEC_60559_TYPES_EXT__ +#include // `_Float16` +#include // `aligned_alloc` + +#include // `std::atomic` +#include // `std::chrono` +#include // `std::strncmp` +#include // `std::thread` + +#include "index.hpp" // `expected_gt` and macros + +#if !defined(USEARCH_USE_OPENMP) +#define USEARCH_USE_OPENMP 0 +#endif + +#if USEARCH_USE_OPENMP +#include // `omp_get_num_threads()` +#endif + +#if defined(USEARCH_DEFINED_LINUX) +#include // `getauxval()` +#endif + +#if !defined(USEARCH_USE_FP16LIB) +#if defined(__AVX512F__) +#define USEARCH_USE_FP16LIB 0 +#elif defined(USEARCH_DEFINED_ARM) +#include // `__fp16` +#define USEARCH_USE_FP16LIB 0 +#else +#define USEARCH_USE_FP16LIB 1 +#endif +#endif + +#if USEARCH_USE_FP16LIB +#include +#endif + +#if !defined(USEARCH_USE_SIMSIMD) +#define USEARCH_USE_SIMSIMD 0 +#endif + +#if USEARCH_USE_SIMSIMD +// Propagate the `f16` settings +#if defined(USEARCH_CAN_COMPILE_FP16) || defined(USEARCH_CAN_COMPILE_FLOAT16) +#if USEARCH_CAN_COMPILE_FP16 || USEARCH_CAN_COMPILE_FLOAT16 +#define SIMSIMD_NATIVE_F16 1 +#else +#define SIMSIMD_NATIVE_F16 0 +#endif +#endif +// Propagate the `bf16` settings +#if defined(USEARCH_CAN_COMPILE_BF16) || defined(USEARCH_CAN_COMPILE_BFLOAT16) +#if USEARCH_CAN_COMPILE_BF16 || USEARCH_CAN_COMPILE_BFLOAT16 +#define SIMSIMD_NATIVE_BF16 1 +#else +#define SIMSIMD_NATIVE_BF16 0 +#endif +#endif +// No problem, if some of the functions are unused or undefined +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wall" +#pragma GCC diagnostic ignored "-Wunused" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4101) // "Unused variables" +#pragma warning(disable : 4068) // "Unknown pragmas", when MSVC tries to read GCC pragmas +#endif // _MSC_VER +#include +#ifdef _MSC_VER +#pragma warning(pop) +#endif // _MSC_VER +#pragma GCC diagnostic pop +#endif + +namespace unum { +namespace usearch { + +using u40_t = uint40_t; +enum b1x8_t : unsigned char {}; + +struct uuid_t { + std::uint8_t octets[16]; +}; + +class f16_bits_t; +class bf16_bits_t; + +using f16_t = f16_bits_t; +using bf16_t = bf16_bits_t; + +using f64_t = double; +using f32_t = float; + +using u64_t = std::uint64_t; +using u32_t = std::uint32_t; +using u16_t = std::uint16_t; +using u8_t = std::uint8_t; + +using i64_t = std::int64_t; +using i32_t = std::int32_t; +using i16_t = std::int16_t; +using i8_t = std::int8_t; + +/** + * @brief Enumerates the most commonly used distance metrics, mostly for dense vector representations. + */ +enum class metric_kind_t : std::uint8_t { + unknown_k = 0, + // Classics: + ip_k = 'i', + cos_k = 'c', + l2sq_k = 'e', + + // Custom: + pearson_k = 'p', + haversine_k = 'h', + divergence_k = 'd', + + // Dense Sets: + hamming_k = 'b', + tanimoto_k = 't', + sorensen_k = 's', + + // Sparse Sets: + jaccard_k = 'j', +}; + +/** + * @brief Enumerates the most commonly used scalar types, mostly for dense vector representations. + * Doesn't include logical types, like complex numbers or quaternions. + */ +enum class scalar_kind_t : std::uint8_t { + unknown_k = 0, + // Custom: + b1x8_k = 1, + u40_k = 2, + uuid_k = 3, + bf16_k = 4, + // Common: + f64_k = 10, + f32_k = 11, + f16_k = 12, + f8_k = 13, + // Common Integral: + u64_k = 14, + u32_k = 15, + u16_k = 16, + u8_k = 17, + i64_k = 20, + i32_k = 21, + i16_k = 22, + i8_k = 23, +}; + +/** + * @brief Maps a scalar type to its corresponding scalar_kind_t enumeration value. + */ +template scalar_kind_t scalar_kind() noexcept { + if (std::is_same()) + return scalar_kind_t::b1x8_k; + if (std::is_same()) + return scalar_kind_t::u40_k; + if (std::is_same()) + return scalar_kind_t::uuid_k; + if (std::is_same()) + return scalar_kind_t::f64_k; + if (std::is_same()) + return scalar_kind_t::f32_k; + if (std::is_same()) + return scalar_kind_t::f16_k; + if (std::is_same()) + return scalar_kind_t::bf16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + if (std::is_same()) + return scalar_kind_t::u64_k; + if (std::is_same()) + return scalar_kind_t::u32_k; + if (std::is_same()) + return scalar_kind_t::u16_k; + if (std::is_same()) + return scalar_kind_t::u8_k; + if (std::is_same()) + return scalar_kind_t::i64_k; + if (std::is_same()) + return scalar_kind_t::i32_k; + if (std::is_same()) + return scalar_kind_t::i16_k; + if (std::is_same()) + return scalar_kind_t::i8_k; + return scalar_kind_t::unknown_k; +} + +/** + * @brief Converts an angle from degrees to radians. + */ +template at angle_to_radians(at angle) noexcept { return angle * at(3.14159265358979323846) / at(180); } + +/** + * @brief Readability helper to compute the square of a given value. + */ +template at square(at value) noexcept { return value * value; } + +/** + * @brief Clamps a value between a lower and upper bound using a custom comparator. Similar to `std::clamp`. + * https://en.cppreference.com/w/cpp/algorithm/clamp + */ +template inline at clamp(at v, at lo, at hi, compare_at comp) noexcept { + return comp(v, lo) ? lo : comp(hi, v) ? hi : v; +} + +/** + * @brief Clamps a value between a lower and upper bound. Similar to `std::clamp`. + * https://en.cppreference.com/w/cpp/algorithm/clamp + */ +template inline at clamp(at v, at lo, at hi) noexcept { + return usearch::clamp(v, lo, hi, std::less{}); +} + +/** + * @brief Compares two strings for equality, given a length for the first string. + */ +inline bool str_equals(char const* first_begin, std::size_t first_len, char const* second_begin) noexcept { + std::size_t second_len = std::strlen(second_begin); + return first_len == second_len && std::strncmp(first_begin, second_begin, first_len) == 0; +} + +/** + * @brief Returns the number of bits required to represent a scalar type. + */ +inline std::size_t bits_per_scalar(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::uuid_k: return 128; + case scalar_kind_t::u40_k: return 40; + case scalar_kind_t::bf16_k: return 16; + case scalar_kind_t::b1x8_k: return 1; + case scalar_kind_t::u64_k: return 64; + case scalar_kind_t::i64_k: return 64; + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::u32_k: return 32; + case scalar_kind_t::i32_k: return 32; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::u16_k: return 16; + case scalar_kind_t::i16_k: return 16; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::u8_k: return 8; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::f8_k: return 8; + default: return 0; + } +} + +/** + * @brief Returns the number of bits in a scalar word for a given scalar type. + * Equivalent to `bits_per_scalar` for types that are not bit-packed. + */ +inline std::size_t bits_per_scalar_word(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::uuid_k: return 128; + case scalar_kind_t::u40_k: return 40; + case scalar_kind_t::bf16_k: return 16; + case scalar_kind_t::b1x8_k: return 8; + case scalar_kind_t::u64_k: return 64; + case scalar_kind_t::i64_k: return 64; + case scalar_kind_t::f64_k: return 64; + case scalar_kind_t::u32_k: return 32; + case scalar_kind_t::i32_k: return 32; + case scalar_kind_t::f32_k: return 32; + case scalar_kind_t::u16_k: return 16; + case scalar_kind_t::i16_k: return 16; + case scalar_kind_t::f16_k: return 16; + case scalar_kind_t::u8_k: return 8; + case scalar_kind_t::i8_k: return 8; + case scalar_kind_t::f8_k: return 8; + default: return 0; + } +} + +/** + * @brief Returns the string name of a given scalar type. + */ +inline char const* scalar_kind_name(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::uuid_k: return "uuid"; + case scalar_kind_t::u40_k: return "u40"; + case scalar_kind_t::bf16_k: return "bf16"; + case scalar_kind_t::b1x8_k: return "b1x8"; + case scalar_kind_t::u64_k: return "u64"; + case scalar_kind_t::i64_k: return "i64"; + case scalar_kind_t::f64_k: return "f64"; + case scalar_kind_t::u32_k: return "u32"; + case scalar_kind_t::i32_k: return "i32"; + case scalar_kind_t::f32_k: return "f32"; + case scalar_kind_t::u16_k: return "u16"; + case scalar_kind_t::i16_k: return "i16"; + case scalar_kind_t::f16_k: return "f16"; + case scalar_kind_t::u8_k: return "u8"; + case scalar_kind_t::i8_k: return "i8"; + case scalar_kind_t::f8_k: return "f8"; + default: return ""; + } +} + +/** + * @brief Returns the string name of a given distance metric. + */ +inline char const* metric_kind_name(metric_kind_t metric) noexcept { + switch (metric) { + case metric_kind_t::unknown_k: return "unknown"; + case metric_kind_t::ip_k: return "ip"; + case metric_kind_t::cos_k: return "cos"; + case metric_kind_t::l2sq_k: return "l2sq"; + case metric_kind_t::pearson_k: return "pearson"; + case metric_kind_t::haversine_k: return "haversine"; + case metric_kind_t::divergence_k: return "divergence"; + case metric_kind_t::jaccard_k: return "jaccard"; + case metric_kind_t::hamming_k: return "hamming"; + case metric_kind_t::tanimoto_k: return "tanimoto"; + case metric_kind_t::sorensen_k: return "sorensen"; + default: return ""; + } +} + +/** + * @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value. + */ +inline expected_gt scalar_kind_from_name(char const* name, std::size_t len) { + expected_gt parsed; + if (str_equals(name, len, "f32")) + parsed.result = scalar_kind_t::f32_k; + else if (str_equals(name, len, "f64")) + parsed.result = scalar_kind_t::f64_k; + else if (str_equals(name, len, "f16")) + parsed.result = scalar_kind_t::f16_k; + else if (str_equals(name, len, "bf16")) + parsed.result = scalar_kind_t::bf16_k; + else if (str_equals(name, len, "i8")) + parsed.result = scalar_kind_t::i8_k; + else if (str_equals(name, len, "b1")) + parsed.result = scalar_kind_t::b1x8_k; + else + parsed.failed("Unknown type, choose: f64, f32, f16, bf16, i8, b1"); + return parsed; +} + +/** + * @brief Parses a string to identify the corresponding `scalar_kind_t` enumeration value. + */ +inline expected_gt scalar_kind_from_name(char const* name) { + return scalar_kind_from_name(name, std::strlen(name)); +} + +/** + * @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value. + */ +inline expected_gt metric_from_name(char const* name, std::size_t len) { + expected_gt parsed; + if (str_equals(name, len, "l2sq") || str_equals(name, len, "euclidean_sq")) { + parsed.result = metric_kind_t::l2sq_k; + } else if (str_equals(name, len, "ip") || str_equals(name, len, "inner") || str_equals(name, len, "dot")) { + parsed.result = metric_kind_t::ip_k; + } else if (str_equals(name, len, "cos") || str_equals(name, len, "angular")) { + parsed.result = metric_kind_t::cos_k; + } else if (str_equals(name, len, "haversine")) { + parsed.result = metric_kind_t::haversine_k; + } else if (str_equals(name, len, "divergence")) { + parsed.result = metric_kind_t::divergence_k; + } else if (str_equals(name, len, "pearson")) { + parsed.result = metric_kind_t::pearson_k; + } else if (str_equals(name, len, "hamming")) { + parsed.result = metric_kind_t::hamming_k; + } else if (str_equals(name, len, "tanimoto")) { + parsed.result = metric_kind_t::tanimoto_k; + } else if (str_equals(name, len, "sorensen")) { + parsed.result = metric_kind_t::sorensen_k; + } else + parsed.failed("Unknown distance, choose: l2sq, ip, cos, haversine, divergence, jaccard, pearson, hamming, " + "tanimoto, sorensen"); + return parsed; +} + +/** + * @brief Parses a string to identify the corresponding `metric_kind_t` enumeration value. + */ +inline expected_gt metric_from_name(char const* name) { + return metric_from_name(name, std::strlen(name)); +} + +/** + * @brief Convenience function to upcast a half-precision floating point number to a single-precision one. + */ +inline float f16_to_f32(std::uint16_t u16) noexcept { +#if USEARCH_USE_FP16LIB + return fp16_ieee_to_fp32_value(u16); +#elif USEARCH_USE_SIMSIMD + return simsimd_f16_to_f32((simsimd_f16_t const*)&u16); +#else +#warning "It's recommended to use SimSIMD and fp16lib for half-precision numerics" + _Float16 f16; + std::memcpy(&f16, &u16, sizeof(std::uint16_t)); + return float(f16); +#endif +} + +/** + * @brief Convenience function to downcast a single-precision floating point number to a half-precision one. + */ +inline std::uint16_t f32_to_f16(float f32) noexcept { +#if USEARCH_USE_FP16LIB + return fp16_ieee_from_fp32_value(f32); +#elif USEARCH_USE_SIMSIMD + std::uint16_t result; + simsimd_f32_to_f16(f32, (simsimd_f16_t*)&result); + return result; +#else +#warning "It's recommended to use SimSIMD and fp16lib for half-precision numerics" + _Float16 f16 = _Float16(f32); + std::uint16_t u16; + std::memcpy(&u16, &f16, sizeof(std::uint16_t)); + return u16; +#endif +} + +/** + * @brief Convenience function to upcast a brain-floating point number to a single-precision one. + * https://github.com/ashvardanian/SimSIMD/blob/ff51434d90c66f916e94ff05b24530b127aa4cff/include/simsimd/types.h#L395-L410 + */ +inline float bf16_to_f32(std::uint16_t u16) noexcept { +#if USEARCH_USE_SIMSIMD + return simsimd_bf16_to_f32((simsimd_bf16_t const*)&u16); +#else + union float_or_unsigned_int_t { + float f; + unsigned int i; + } conv; + conv.i = u16 << 16; // Zero extends the mantissa + return conv.f; +#endif +} + +/** + * @brief Convenience function to downcast a single-precision floating point number to a brain-floating point one. + * https://github.com/ashvardanian/SimSIMD/blob/ff51434d90c66f916e94ff05b24530b127aa4cff/include/simsimd/types.h#L412-L425 + */ +inline std::uint16_t f32_to_bf16(float f32) noexcept { +#if USEARCH_USE_SIMSIMD + std::uint16_t result; + simsimd_f32_to_bf16(f32, (simsimd_bf16_t*)&result); + return result; +#else + union float_or_unsigned_int_t { + float f; + unsigned int i; + } conv; + conv.f = f32; + conv.i >>= 16; + conv.i &= 0xFFFF; + return (unsigned short)conv.i; +#endif +} + +/** + * @brief Numeric type for the IEEE 754 half-precision floating point. + * If hardware support isn't available, falls back to a hardware + * agnostic in-software implementation. + */ +class f16_bits_t { + std::uint16_t uint16_{}; + + public: + inline f16_bits_t() noexcept : uint16_(0) {} + inline f16_bits_t(f16_bits_t&&) = default; + inline f16_bits_t& operator=(f16_bits_t&&) = default; + inline f16_bits_t(f16_bits_t const&) = default; + inline f16_bits_t& operator=(f16_bits_t const&) = default; + + inline operator float() const noexcept { return f16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return f16_to_f32(uint16_) > 0.5f; } + + inline f16_bits_t(int v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + inline f16_bits_t(bool v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + inline f16_bits_t(float v) noexcept : uint16_(f32_to_f16(v)) {} + inline f16_bits_t(double v) noexcept : uint16_(f32_to_f16(static_cast(v))) {} + + inline bool operator<(f16_bits_t const& other) const noexcept { return float(*this) < float(other); } + + inline f16_bits_t operator+(f16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline f16_bits_t operator-(f16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline f16_bits_t operator*(f16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline f16_bits_t operator/(f16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline float operator+(float other) const noexcept { return float(*this) + other; } + inline float operator-(float other) const noexcept { return float(*this) - other; } + inline float operator*(float other) const noexcept { return float(*this) * other; } + inline float operator/(float other) const noexcept { return float(*this) / other; } + inline double operator+(double other) const noexcept { return float(*this) + other; } + inline double operator-(double other) const noexcept { return float(*this) - other; } + inline double operator*(double other) const noexcept { return float(*this) * other; } + inline double operator/(double other) const noexcept { return float(*this) / other; } + + inline f16_bits_t& operator+=(float v) noexcept { + uint16_ = f32_to_f16(v + f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator-=(float v) noexcept { + uint16_ = f32_to_f16(v - f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator*=(float v) noexcept { + uint16_ = f32_to_f16(v * f16_to_f32(uint16_)); + return *this; + } + + inline f16_bits_t& operator/=(float v) noexcept { + uint16_ = f32_to_f16(v / f16_to_f32(uint16_)); + return *this; + } +}; + +#if USEARCH_USE_OPENMP +#pragma omp declare reduction(+ : unum::usearch::f16_bits_t : omp_out = omp_out + omp_in) \ + initializer(omp_priv = unum::usearch::f16_bits_t()) +#endif + +/** + * @brief Numeric type for brain-floating point half-precision floating point. + * If hardware support isn't available, falls back to a hardware + * agnostic in-software implementation. + */ +class bf16_bits_t { + std::uint16_t uint16_{}; + + public: + inline bf16_bits_t() noexcept : uint16_(0) {} + inline bf16_bits_t(bf16_bits_t&&) = default; + inline bf16_bits_t& operator=(bf16_bits_t&&) = default; + inline bf16_bits_t(bf16_bits_t const&) = default; + inline bf16_bits_t& operator=(bf16_bits_t const&) = default; + + inline operator float() const noexcept { return bf16_to_f32(uint16_); } + inline explicit operator bool() const noexcept { return bf16_to_f32(uint16_) > 0.5f; } + + inline bf16_bits_t(int v) noexcept : uint16_(f32_to_bf16(static_cast(v))) {} + inline bf16_bits_t(bool v) noexcept : uint16_(f32_to_bf16(static_cast(v))) {} + inline bf16_bits_t(float v) noexcept : uint16_(f32_to_bf16(v)) {} + inline bf16_bits_t(double v) noexcept : uint16_(f32_to_bf16(static_cast(v))) {} + + inline bool operator<(bf16_bits_t const& other) const noexcept { return float(*this) < float(other); } + + inline bf16_bits_t operator+(bf16_bits_t other) const noexcept { return {float(*this) + float(other)}; } + inline bf16_bits_t operator-(bf16_bits_t other) const noexcept { return {float(*this) - float(other)}; } + inline bf16_bits_t operator*(bf16_bits_t other) const noexcept { return {float(*this) * float(other)}; } + inline bf16_bits_t operator/(bf16_bits_t other) const noexcept { return {float(*this) / float(other)}; } + inline float operator+(float other) const noexcept { return float(*this) + other; } + inline float operator-(float other) const noexcept { return float(*this) - other; } + inline float operator*(float other) const noexcept { return float(*this) * other; } + inline float operator/(float other) const noexcept { return float(*this) / other; } + inline double operator+(double other) const noexcept { return float(*this) + other; } + inline double operator-(double other) const noexcept { return float(*this) - other; } + inline double operator*(double other) const noexcept { return float(*this) * other; } + inline double operator/(double other) const noexcept { return float(*this) / other; } + + inline bf16_bits_t& operator+=(float v) noexcept { + uint16_ = f32_to_bf16(v + bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator-=(float v) noexcept { + uint16_ = f32_to_bf16(v - bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator*=(float v) noexcept { + uint16_ = f32_to_bf16(v * bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator/=(float v) noexcept { + uint16_ = f32_to_bf16(v / bf16_to_f32(uint16_)); + return *this; + } + + inline bf16_bits_t& operator=(float v) noexcept { + uint16_ = f32_to_bf16(v); + return *this; + } +}; + +#if USEARCH_USE_OPENMP +#pragma omp declare reduction(+ : unum::usearch::bf16_bits_t : omp_out = omp_out + omp_in) \ + initializer(omp_priv = unum::usearch::bf16_bits_t()) +#endif + +/** + * @brief An STL-based executor or a "thread-pool" for parallel execution. + * Isn't efficient for small batches, as it recreates the threads on every call. + */ +class executor_stl_t { + std::size_t threads_count_{}; + + struct jthread_t { + std::thread native_; + bool initialized_ = false; + + jthread_t() = default; + jthread_t(jthread_t&&) = default; + jthread_t(jthread_t const&) = delete; + template + jthread_t(callable_at&& func) : native_([=]() { func(); }), initialized_(true) {} + + ~jthread_t() { + if (initialized_ && native_.joinable()) + native_.join(); + } + }; + + public: + /** + * @param threads_count The number of threads to be used for parallel execution. + */ + executor_stl_t(std::size_t threads_count = 0) noexcept + : threads_count_(threads_count ? threads_count : std::thread::hardware_concurrency()) {} + + /** + * @return Maximum number of threads available to the executor. + */ + std::size_t size() const noexcept { return threads_count_; } + + /** + * @brief Executes a fixed number of tasks using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + buffer_gt threads_pool(threads_count_ - 1); // Allocate space for threads minus the main thread + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + new (&threads_pool[thread_idx - 1]) jthread_t([=]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread); ++task_idx) + thread_aware_function(thread_idx, task_idx); + }); + } + } + for (std::size_t task_idx = 0; task_idx < (std::min)(tasks, tasks_per_thread); ++task_idx) + thread_aware_function(0, task_idx); + } + + /** + * @brief Executes limited number of tasks using the specified thread-aware function. + * @param tasks The upper bound on the number of tasks. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + buffer_gt threads_pool(threads_count_ - 1); + std::size_t tasks_per_thread = tasks; + std::size_t threads_count = (std::min)(threads_count_, tasks); + std::atomic_bool stop{false}; + if (threads_count > 1) { + tasks_per_thread = (tasks / threads_count) + ((tasks % threads_count) != 0); + for (std::size_t thread_idx = 1; thread_idx < threads_count; ++thread_idx) { + new (&threads_pool[thread_idx - 1]) jthread_t([=, &stop]() { + for (std::size_t task_idx = thread_idx * tasks_per_thread; + task_idx < (std::min)(tasks, thread_idx * tasks_per_thread + tasks_per_thread) && + !stop.load(std::memory_order_relaxed); + ++task_idx) + if (!thread_aware_function(thread_idx, task_idx)) + stop.store(true, std::memory_order_relaxed); + }); + } + } + for (std::size_t task_idx = 0; + task_idx < (std::min)(tasks, tasks_per_thread) && !stop.load(std::memory_order_relaxed); ++task_idx) + if (!thread_aware_function(0, task_idx)) + stop.store(true, std::memory_order_relaxed); + } + + /** + * @brief Saturates every available thread with the given workload, until they finish. + * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { + if (threads_count_ == 1) + return thread_aware_function(0); + buffer_gt threads_pool(threads_count_ - 1); + for (std::size_t thread_idx = 1; thread_idx < threads_count_; ++thread_idx) + new (&threads_pool[thread_idx - 1]) jthread_t([=]() { thread_aware_function(thread_idx); }); + thread_aware_function(0); + } +}; + +#if USEARCH_USE_OPENMP + +/** + * @brief An OpenMP-based executor or a "thread-pool" for parallel execution. + * Is the preferred implementation, when available, and maximum performance is needed. + */ +class executor_openmp_t { + public: + /** + * @param threads_count The number of threads to be used for parallel execution. + */ + executor_openmp_t(std::size_t threads_count = 0) noexcept { + omp_set_num_threads(static_cast(threads_count ? threads_count : std::thread::hardware_concurrency())); + } + + /** + * @return Maximum number of threads available to the executor. + */ + std::size_t size() const noexcept { return omp_get_max_threads(); } + + /** + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void fixed(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { +#pragma omp parallel for schedule(dynamic, 1) + for (std::size_t i = 0; i != tasks; ++i) { + thread_aware_function(omp_get_thread_num(), i); + } + } + + /** + * @brief Executes tasks in bulk using the specified thread-aware function. + * @param tasks The total number of tasks to be executed. + * @param thread_aware_function The thread-aware function to be called for each thread index and task index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void dynamic(std::size_t tasks, thread_aware_function_at&& thread_aware_function) noexcept(false) { + // OpenMP cancellation points are not yet available on most platforms, and require + // the `OMP_CANCELLATION` environment variable to be set. + // http://jakascorner.com/blog/2016/08/omp-cancel.html + // if (omp_get_cancellation()) { + // #pragma omp parallel for schedule(dynamic, 1) + // for (std::size_t i = 0; i != tasks; ++i) { + // #pragma omp cancellation point for + // if (!thread_aware_function(omp_get_thread_num(), i)) { + // #pragma omp cancel for + // } + // } + // } + std::atomic_bool stop{false}; +#pragma omp parallel for schedule(dynamic, 1) shared(stop) + for (std::size_t i = 0; i != tasks; ++i) { + if (!stop.load(std::memory_order_relaxed) && !thread_aware_function(omp_get_thread_num(), i)) + stop.store(true, std::memory_order_relaxed); + } + } + + /** + * @brief Saturates every available thread with the given workload, until they finish. + * @param thread_aware_function The thread-aware function to be called for each thread index. + * @throws If an exception occurs during execution of the thread-aware function. + */ + template + void parallel(thread_aware_function_at&& thread_aware_function) noexcept(false) { +#pragma omp parallel + { + thread_aware_function(omp_get_thread_num()); + } + } +}; + +using executor_default_t = executor_openmp_t; + +#else + +using executor_default_t = executor_stl_t; + +#endif + +/** + * @brief Uses OS-specific APIs for aligned memory allocations. + * Available since C11, but only C++17, so we wrap the C version. + */ +template // +class aligned_allocator_gt { + public: + using value_type = element_at; + using size_type = std::size_t; + using pointer = element_at*; + using const_pointer = element_at const*; + template struct rebind { + using other = aligned_allocator_gt; + }; + + constexpr std::size_t alignment() const { return alignment_ak; } + + pointer allocate(size_type length) const { + std::size_t length_bytes = alignment_ak * divide_round_up(length * sizeof(value_type)); + // Avoid overflow + if (length > length_bytes) + return nullptr; + std::size_t alignment = alignment_ak; +#if defined(USEARCH_DEFINED_WINDOWS) + return (pointer)_aligned_malloc(length_bytes, alignment); +#elif defined(USEARCH_DEFINED_APPLE) || defined(USEARCH_DEFINED_ANDROID) + // Apple Clang keeps complaining that `aligned_alloc` is only available + // with macOS 10.15 and newer or Android API >= 28, so let's use `posix_memalign` there. + void* result = nullptr; + int status = posix_memalign(&result, alignment, length_bytes); + return status == 0 ? (pointer)result : nullptr; +#else + return (pointer)aligned_alloc(alignment, length_bytes); +#endif + } + + void deallocate(pointer begin, size_type) const { +#if defined(USEARCH_DEFINED_WINDOWS) + _aligned_free(begin); +#else + free(begin); +#endif + } +}; + +using aligned_allocator_t = aligned_allocator_gt<>; + +/** + * @brief A simple RAM-page allocator that uses the OS-specific APIs for memory allocation. + * Shouldn't be used frequently, as system calls are slow. + */ +class page_allocator_t { + public: + static constexpr std::size_t page_size() { return 4096; } + + /** + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + */ + byte_t* allocate(std::size_t count_bytes) const noexcept { + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); +#if defined(USEARCH_DEFINED_WINDOWS) + return (byte_t*)(::VirtualAlloc(NULL, count_bytes, MEM_COMMIT | MEM_RESERVE, PAGE_READWRITE)); +#else + return (byte_t*)mmap(NULL, count_bytes, PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS, 0, 0); +#endif + } + + void deallocate(byte_t* page_pointer, std::size_t count_bytes) const noexcept { +#if defined(USEARCH_DEFINED_WINDOWS) + ::VirtualFree(page_pointer, 0, MEM_RELEASE); +#else + count_bytes = divide_round_up(count_bytes, page_size()) * page_size(); + munmap(page_pointer, count_bytes); +#endif + } +}; + +/** + * @brief Memory-mapping allocator designed for "alloc many, free at once" usage patterns. + * @b Thread-safe, @b except constructors and destructors. + * + * Using this memory allocator won't affect your overall speed much, as that is not the bottleneck. + * However, it can drastically improve memory usage especially for huge indexes of small vectors. + */ +template class memory_mapping_allocator_gt { + + static constexpr std::size_t min_capacity() { return 1024 * 1024 * 4; } + static constexpr std::size_t capacity_multiplier() { return 2; } + static constexpr std::size_t head_size() { + /// Pointer to the the previous arena and the size of the current one. + return divide_round_up(sizeof(byte_t*) + sizeof(std::size_t)) * alignment_ak; + } + + std::mutex mutex_; + byte_t* last_arena_ = nullptr; + std::size_t last_usage_ = head_size(); + std::size_t last_capacity_ = min_capacity(); + std::size_t wasted_space_ = 0; + + public: + using value_type = byte_t; + using size_type = std::size_t; + using pointer = byte_t*; + using const_pointer = byte_t const*; + + memory_mapping_allocator_gt() = default; + memory_mapping_allocator_gt(memory_mapping_allocator_gt&& other) noexcept + : last_arena_(exchange(other.last_arena_, nullptr)), last_usage_(exchange(other.last_usage_, 0)), + last_capacity_(exchange(other.last_capacity_, 0)), wasted_space_(exchange(other.wasted_space_, 0)) {} + + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt&& other) noexcept { + std::swap(last_arena_, other.last_arena_); + std::swap(last_usage_, other.last_usage_); + std::swap(last_capacity_, other.last_capacity_); + std::swap(wasted_space_, other.wasted_space_); + return *this; + } + + ~memory_mapping_allocator_gt() noexcept { reset(); } + + /** + * @brief Discards all previously allocated memory buffers. + */ + void reset() noexcept { + byte_t* last_arena = last_arena_; + while (last_arena) { + byte_t* previous_arena = nullptr; + std::memcpy(&previous_arena, last_arena, sizeof(byte_t*)); + std::size_t last_cap = 0; + std::memcpy(&last_cap, last_arena + sizeof(byte_t*), sizeof(std::size_t)); + page_allocator_t{}.deallocate(last_arena, last_cap); + last_arena = previous_arena; + } + + // Clear the references: + last_arena_ = nullptr; + last_usage_ = head_size(); + last_capacity_ = min_capacity(); + wasted_space_ = 0; + } + + /** + * @brief Copy constructor. + * @note This is a no-op copy constructor since the allocator is not copyable. + */ + memory_mapping_allocator_gt(memory_mapping_allocator_gt const&) noexcept {} + + /** + * @brief Copy assignment operator. + * @note This is a no-op copy assignment operator since the allocator is not copyable. + * @return Reference to the allocator after the assignment. + */ + memory_mapping_allocator_gt& operator=(memory_mapping_allocator_gt const&) noexcept { + reset(); + return *this; + } + + /** + * @brief Allocates an @b uninitialized block of memory of the specified size. + * @param count_bytes The number of bytes to allocate. + * @return A pointer to the allocated memory block, or `nullptr` if allocation fails. + */ + inline byte_t* allocate(std::size_t count_bytes) noexcept { + std::size_t extended_bytes = divide_round_up(count_bytes) * alignment_ak; + std::unique_lock lock(mutex_); + if (!last_arena_ || (last_usage_ + extended_bytes >= last_capacity_)) { + std::size_t new_cap = (std::max)(last_capacity_, ceil2(extended_bytes)) * capacity_multiplier(); + byte_t* new_arena = page_allocator_t{}.allocate(new_cap); + if (!new_arena) + return nullptr; + std::memcpy(new_arena, &last_arena_, sizeof(byte_t*)); + std::memcpy(new_arena + sizeof(byte_t*), &new_cap, sizeof(std::size_t)); + + wasted_space_ += total_reserved(); + last_arena_ = new_arena; + last_capacity_ = new_cap; + last_usage_ = head_size(); + } + + wasted_space_ += extended_bytes - count_bytes; + return last_arena_ + exchange(last_usage_, last_usage_ + extended_bytes); + } + + /** + * @brief Returns the amount of memory used by the allocator across all arenas. + * @return The amount of space in bytes. + */ + std::size_t total_allocated() const noexcept { + if (!last_arena_) + return 0; + std::size_t total_used = 0; + std::size_t last_capacity = last_capacity_; + do { + total_used += last_capacity; + last_capacity /= capacity_multiplier(); + } while (last_capacity >= min_capacity()); + return total_used; + } + + /** + * @brief Returns the amount of wasted space due to alignment. + * @return The amount of wasted space in bytes. + */ + std::size_t total_wasted() const noexcept { return wasted_space_; } + + /** + * @brief Returns the amount of remaining memory already reserved but not yet used. + * @return The amount of reserved memory in bytes. + */ + std::size_t total_reserved() const noexcept { return last_arena_ ? last_capacity_ - last_usage_ : 0; } + + /** + * @warning The very first memory de-allocation discards all the arenas! + */ + void deallocate(byte_t* = nullptr, std::size_t = 0) noexcept { reset(); } +}; + +using memory_mapping_allocator_t = memory_mapping_allocator_gt<>; + +/** + * @brief C++11 userspace implementation of an oversimplified `std::shared_mutex`, + * that assumes rare interleaving of shared and unique locks. It's not fair, + * but requires only a single 32-bit atomic integer to work. + */ +class unfair_shared_mutex_t { + /** Any positive integer describes the number of concurrent readers */ + enum state_t : std::int32_t { + idle_k = 0, + writing_k = -1, + }; + std::atomic state_{idle_k}; + + public: + inline void lock() noexcept { + std::int32_t raw; + relock: + raw = idle_k; + if (!state_.compare_exchange_weak(raw, writing_k, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock; + } + } + + inline void unlock() noexcept { state_.store(idle_k, std::memory_order_release); } + + inline void lock_shared() noexcept { + std::int32_t raw; + relock_shared: + raw = state_.load(std::memory_order_acquire); + // Spin while it's uniquely locked + if (raw == writing_k) { + std::this_thread::yield(); + goto relock_shared; + } + // Try incrementing the counter + if (!state_.compare_exchange_weak(raw, raw + 1, std::memory_order_acquire, std::memory_order_relaxed)) { + std::this_thread::yield(); + goto relock_shared; + } + } + + inline void unlock_shared() noexcept { state_.fetch_sub(1, std::memory_order_release); } + + /** + * @brief Try upgrades the current `lock_shared()` to a unique `lock()` state. + */ + inline bool try_escalate() noexcept { + std::int32_t one_read = 1; + return state_.compare_exchange_weak(one_read, writing_k, std::memory_order_acquire, std::memory_order_relaxed); + } + + /** + * @brief Escalates current lock potentially loosing control in the middle. + * It's a shortcut for `try_escalate`-`unlock_shared`-`lock` trio. + */ + inline void unsafe_escalate() noexcept { + if (!try_escalate()) { + unlock_shared(); + lock(); + } + } + + /** + * @brief Upgrades the current `lock_shared()` to a unique `lock()` state. + */ + inline void escalate() noexcept { + while (!try_escalate()) + std::this_thread::yield(); + } + + /** + * @brief De-escalation of a previously escalated state. + */ + inline void de_escalate() noexcept { + std::int32_t one_read = 1; + state_.store(one_read, std::memory_order_release); + } +}; + +template class shared_lock_gt { + mutex_at& mutex_; + + public: + inline explicit shared_lock_gt(mutex_at& m) noexcept : mutex_(m) { mutex_.lock_shared(); } + inline ~shared_lock_gt() noexcept { mutex_.unlock_shared(); } +}; + +/** + * @brief Utility class used to cast arrays of one scalar type to another, + * avoiding unnecessary conversions. + */ +template struct cast_gt { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + from_scalar_at const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + auto converter = [](from_scalar_at from) { return to_scalar_at(from); }; + std::transform(typed_input, typed_input + dim, typed_output, converter); + return true; + } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template <> struct cast_gt { + static bool try_(byte_t const*, std::size_t, byte_t*) noexcept { return false; } +}; + +template struct cast_to_b1x8_gt { + inline static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + from_scalar_at const* typed_input = reinterpret_cast(input); + unsigned char* typed_output = reinterpret_cast(output); + std::memset(typed_output, 0, dim / CHAR_BIT); + for (std::size_t i = 0; i != dim; ++i) + // Converting from scalar types to boolean isn't trivial and depends on the type. + // The most common case is to consider all positive values as `true` and all others as `false`. + // - `bool(0.00001f)` converts to 1 + // - `bool(-0.00001f)` converts to 1 + // - `bool(0)` converts to 0 + // - `bool(-0)` converts to 0 + // - `bool(std::numeric_limits::infinity())` converts to 1 + // - `bool(std::numeric_limits::epsilon())` converts to 1 + // - `bool(std::numeric_limits::signaling_NaN())` converts to 1 + // - `bool(std::numeric_limits::denorm_min())` converts to 1 + typed_output[i / CHAR_BIT] |= bool(typed_input[i] > 0) ? (128 >> (i & (CHAR_BIT - 1))) : 0; + return true; + } +}; + +template struct cast_from_b1x8_gt { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + unsigned char const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + // We can't entirely reconstruct the original scalar type from a boolean. + // The simplest variant would be to map set bits to ones, and unset bits to zeros. + typed_output[i] = bool(typed_input[i / CHAR_BIT] & (128 >> (i & (CHAR_BIT - 1)))); + return true; + } +}; + +template struct cast_to_i8_gt { + inline static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + from_scalar_at const* typed_input = reinterpret_cast(input); + std::int8_t* typed_output = reinterpret_cast(output); + // Unlike other casting mechanisms, switching to small range integers is a two step procedure. + // First we want to estimate the magnitude of the vector to scale into [-1.0, 1.0] interval, + // instead of clamping. And then we scale the values into the [-127, 127] range. + // ! This makes an assumption, that the distance metric is dot-product-like, which may not + // ! be true in many cases, so it's recommended to avoid automatic casting from floats to + // ! integers. + double magnitude = 0.0; + for (std::size_t i = 0; i != dim; ++i) + magnitude += (double)typed_input[i] * (double)typed_input[i]; + magnitude = std::sqrt(magnitude); + for (std::size_t i = 0; i != dim; ++i) + typed_output[i] = + static_cast(usearch::clamp(typed_input[i] * 127.0 / magnitude, -127.0, 127.0)); + return true; + } +}; + +template struct cast_from_i8_gt { + static bool try_(byte_t const* input, std::size_t dim, byte_t* output) noexcept { + std::int8_t const* typed_input = reinterpret_cast(input); + to_scalar_at* typed_output = reinterpret_cast(output); + for (std::size_t i = 0; i != dim; ++i) + typed_output[i] = static_cast(typed_input[i]) / 127.f; + return true; + } +}; + +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; +template <> struct cast_gt : public cast_from_i8_gt {}; + +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; +template <> struct cast_gt : public cast_to_i8_gt {}; + +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_from_b1x8_gt {}; + +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; + +template <> struct cast_gt : public cast_from_b1x8_gt {}; +template <> struct cast_gt : public cast_to_b1x8_gt {}; + +/** + * @brief Type-punned array casting function. + * Arguments: input buffer, bytes in input buffer, output buffer. + * Returns `true` if the casting was performed successfully, `false` otherwise. + */ +using cast_punned_t = bool (*)(byte_t const*, std::size_t, byte_t*); + +/** + * @brief A collection of casting functions for typical vector types. + * Covers to/from conversions for boolean, integer, half-precision, + * single-precision, and double-precision scalars. + */ +struct casts_punned_t { + struct group_t { + cast_punned_t b1x8{}; + cast_punned_t i8{}; + cast_punned_t f16{}; + cast_punned_t bf16{}; + cast_punned_t f32{}; + cast_punned_t f64{}; + + cast_punned_t operator[](scalar_kind_t scalar_kind) const noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return f64; + case scalar_kind_t::f32_k: return f32; + case scalar_kind_t::f16_k: return f16; + case scalar_kind_t::bf16_k: return bf16; + case scalar_kind_t::i8_k: return i8; + case scalar_kind_t::b1x8_k: return b1x8; + default: return nullptr; + } + } + + } from, to; + + template static casts_punned_t make() noexcept { + casts_punned_t result; + + result.from.b1x8 = &cast_gt::try_; + result.from.i8 = &cast_gt::try_; + result.from.f16 = &cast_gt::try_; + result.from.bf16 = &cast_gt::try_; + result.from.f32 = &cast_gt::try_; + result.from.f64 = &cast_gt::try_; + + result.to.b1x8 = &cast_gt::try_; + result.to.i8 = &cast_gt::try_; + result.to.f16 = &cast_gt::try_; + result.to.bf16 = &cast_gt::try_; + result.to.f32 = &cast_gt::try_; + result.to.f64 = &cast_gt::try_; + + return result; + } + + static casts_punned_t make(scalar_kind_t scalar_kind) noexcept { + switch (scalar_kind) { + case scalar_kind_t::f64_k: return casts_punned_t::make(); + case scalar_kind_t::f32_k: return casts_punned_t::make(); + case scalar_kind_t::f16_k: return casts_punned_t::make(); + case scalar_kind_t::bf16_k: return casts_punned_t::make(); + case scalar_kind_t::i8_k: return casts_punned_t::make(); + case scalar_kind_t::b1x8_k: return casts_punned_t::make(); + default: return {}; + } + } +}; + +/* Don't complain if the vectorization of the inner loops fails: + * + * > warning: loop not vectorized: the optimizer was unable to perform the requested transformation; + * > the transformation might be disabled or specified as part of an unsupported transformation ordering + */ +#if defined(USEARCH_DEFINED_CLANG) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif + +/** + * @brief Inner (Dot) Product distance. + * Vectors should be normalized to unit length, + * otherwise `::metric_cos_gt` should be used instead. + */ +template struct metric_ip_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) + ab += result_t(a[i]) * result_t(b[i]); + return 1 - ab; + } +}; + +/** + * @brief Cosine (Angular) distance. + * Identical to the Inner Product of normalized vectors. + * Unless you are running on an tiny embedded platform, this metric + * is recommended over `::metric_ip_gt` for low-precision scalars. + */ +template struct metric_cos_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab{}, a2{}, b2{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab, a2, b2) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab += ai * bi, a2 += square(ai), b2 += square(bi); + } + + result_t result_if_zero[2][2]; + result_if_zero[0][0] = 1 - ab / (std::sqrt(a2) * std::sqrt(b2)); + result_if_zero[0][1] = result_if_zero[1][0] = 1; + result_if_zero[1][1] = 0; + return result_if_zero[a2 == 0][b2 == 0]; + } +}; + +/** + * @brief Squared Euclidean (L2) distance. + * Square root is avoided at the end, as it won't affect the ordering. + */ +template struct metric_l2sq_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + result_t ab_deltas_sq{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab_deltas_sq) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + ab_deltas_sq += square(ai - bi); + } + return ab_deltas_sq; + } +}; + +/** + * @brief Hamming distance computes the number of differing bits in + * two arrays of integers. An example would be a textual document, + * tokenized and hashed into a fixed-capacity bitset. + */ +template struct metric_hamming_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Hamming distance requires unsigned integral words"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t matches{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : matches) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) + matches += std::bitset(a[i] ^ b[i]).count(); + return matches; + } +}; + +/** + * @brief Tanimoto distance is the intersection over bitwise union. + * Often used in chemistry and biology to compare molecular fingerprints. + */ +template struct metric_tanimoto_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Tanimoto distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Tanimoto distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t or_count{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : and_count, or_count) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + or_count += std::bitset(a[i] | b[i]).count(); + } + return 1 - result_t(and_count) / or_count; + } +}; + +/** + * @brief Sorensen-Dice or F1 distance is the intersection over bitwise union. + * Often used in chemistry and biology to compare molecular fingerprints. + */ +template struct metric_sorensen_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert( // + std::is_unsigned::value || + (std::is_enum::value && std::is_unsigned::type>::value), + "Sorensen-Dice distance requires unsigned integral words"); + static_assert(std::is_floating_point::value, "Sorensen-Dice distance will be a fraction"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t words) const noexcept { + constexpr std::size_t bits_per_word_k = sizeof(scalar_t) * CHAR_BIT; + result_t and_count{}; + result_t any_count{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : and_count, any_count) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != words; ++i) { + and_count += std::bitset(a[i] & b[i]).count(); + any_count += std::bitset(a[i]).count() + std::bitset(b[i]).count(); + } + return 1 - 2 * result_t(and_count) / any_count; + } +}; + +/** + * @brief Counts the number of matching elements in two unique sorted sets. + * Can be used to compute the similarity between two textual documents + * using the IDs of tokens present in them. + * Similar to `metric_tanimoto_gt` for dense representations. + */ +template struct metric_jaccard_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_floating_point::value, "Jaccard distance requires integral scalars"); + static_assert(std::is_floating_point::value, "Jaccard distance returns a fraction"); + + inline result_t operator()( // + scalar_t const* a, scalar_t const* b, std::size_t a_length, std::size_t b_length) const noexcept { + std::size_t intersection{}; + std::size_t i{}; + std::size_t j{}; + while (i != a_length && j != b_length) { + scalar_t ai = a[i]; + scalar_t bj = b[j]; + intersection += ai == bj; + i += ai < bj; + j += ai >= bj; + } + return 1 - static_cast(intersection) / (a_length + b_length - intersection); + } +}; + +/** + * @brief Measures Pearson Correlation between two sequences in a single pass. + */ +template struct metric_pearson_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t dim) const noexcept { + // The correlation coefficient can't be defined for one or zero-dimensional data. + if (dim <= 1) + return 0; + // Conventional Pearson Correlation Coefficient definiton subtracts the mean value of each + // sequence from each element, before dividing them. WikiPedia article suggests a convenient + // single-pass algorithm for calculating sample correlations, though depending on the numbers + // involved, it can sometimes be numerically unstable. + result_t a_sum{}, b_sum{}, ab_sum{}; + result_t a_sq_sum{}, b_sq_sum{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : a_sum, b_sum, ab_sum, a_sq_sum, b_sq_sum) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t ai = static_cast(a[i]); + result_t bi = static_cast(b[i]); + a_sum += ai; + b_sum += bi; + ab_sum += ai * bi; + a_sq_sum += ai * ai; + b_sq_sum += bi * bi; + } + result_t denom = (dim * a_sq_sum - a_sum * a_sum) * (dim * b_sq_sum - b_sum * b_sum); + if (denom == 0) + return 0; + result_t corr = dim * ab_sum - a_sum * b_sum; + denom = std::sqrt(denom); + // The normal Pearson correlation value is between -1 and 1, but we are looking for a distance. + // So instead of returning `corr / denom`, we return `1 - corr / denom`. + return 1 - corr / denom; + } +}; + +/** + * @brief Measures Jensen-Shannon Divergence between two probability distributions. + */ +template struct metric_divergence_gt { + using scalar_t = scalar_at; + using result_t = result_at; + + inline result_t operator()(scalar_t const* p, scalar_t const* q, std::size_t dim) const noexcept { + result_t kld_pm{}, kld_qm{}; + result_t epsilon = std::numeric_limits::epsilon(); +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : kld_pm, kld_qm) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; ++i) { + result_t pi = static_cast(p[i]); + result_t qi = static_cast(q[i]); + result_t mi = (pi + qi) / 2 + epsilon; + kld_pm += pi * std::log((pi + epsilon) / mi); + kld_qm += qi * std::log((qi + epsilon) / mi); + } + return (kld_pm + kld_qm) / 2; + } +}; + +/** + * @brief Cosine (Angular) distance for signed 8-bit integers using 16-bit intermediates. + */ +struct metric_cos_i8_t { + using scalar_t = i8_t; + using result_t = f32_t; + + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab{}, a2{}, b2{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab, a2, b2) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; i++) { + std::int16_t ai{a[i]}; + std::int16_t bi{b[i]}; + ab += ai * bi; + a2 += square(ai); + b2 += square(bi); + } + result_t a2f = std::sqrt(static_cast(a2)); + result_t b2f = std::sqrt(static_cast(b2)); + return (ab != 0) ? (1.f - ab / (a2f * b2f)) : 0; + } +}; + +/** + * @brief Squared Euclidean (L2) distance for signed 8-bit integers using 16-bit intermediates. + * Square root is avoided at the end, as it won't affect the ordering. + */ +struct metric_l2sq_i8_t { + using scalar_t = i8_t; + using result_t = f32_t; + + inline result_t operator()(i8_t const* a, i8_t const* b, std::size_t dim) const noexcept { + std::int32_t ab_deltas_sq{}; +#if USEARCH_USE_OPENMP +#pragma omp simd reduction(+ : ab_deltas_sq) +#elif defined(USEARCH_DEFINED_CLANG) +#pragma clang loop vectorize(enable) +#elif defined(USEARCH_DEFINED_GCC) +#pragma GCC ivdep +#endif + for (std::size_t i = 0; i != dim; i++) + ab_deltas_sq += square(std::int16_t(a[i]) - std::int16_t(b[i])); + return static_cast(ab_deltas_sq); + } +}; + +/** + * @brief Haversine distance for the shortest distance between two nodes on + * the surface of a 3D sphere, defined with latitude and longitude. + */ +template struct metric_haversine_gt { + using scalar_t = scalar_at; + using result_t = result_at; + static_assert(!std::is_integral::value && !std::is_same::value, + "Latitude and longitude must be floating-node"); + + inline result_t operator()(scalar_t const* a, scalar_t const* b, std::size_t = 2) const noexcept { + result_t lat_a = a[0], lon_a = a[1]; + result_t lat_b = b[0], lon_b = b[1]; + + result_t lat_delta = angle_to_radians(lat_b - lat_a) / 2; + result_t lon_delta = angle_to_radians(lon_b - lon_a) / 2; + + result_t converted_lat_a = angle_to_radians(lat_a); + result_t converted_lat_b = angle_to_radians(lat_b); + + result_t x = square(std::sin(lat_delta)) + // + std::cos(converted_lat_a) * std::cos(converted_lat_b) * square(std::sin(lon_delta)); + + return 2 * std::asin(std::sqrt(x)); + } +}; + +using distance_punned_t = float; +using span_punned_t = span_gt; + +/** + * @brief The signature of the user-defined function. + * Can be just two array pointers, precompiled for a specific array length, + * or include one or two array sizes as 64-bit unsigned integers. + */ +enum class metric_punned_signature_t { + array_array_k = 0, + array_array_size_k, + array_array_state_k, +}; + +/** + * @brief Type-punned metric class, which unlike STL's `std::function` avoids any memory allocations. + * It also provides additional APIs to check, if SIMD hardware-acceleration is available. + * Wraps the `simsimd_metric_dense_punned_t` when available. The auto-vectorized backend otherwise. + */ +class metric_punned_t { + public: + using scalar_t = byte_t; + using result_t = distance_punned_t; + + private: + /// In the generalized function API all the are arguments are pointer-sized. + using uptr_t = std::size_t; + /// Distance function that takes two arrays and returns a scalar. + using metric_array_array_t = result_t (*)(uptr_t, uptr_t); + /// Distance function that takes two arrays and their length and returns a scalar. + using metric_array_array_size_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function that takes two arrays and some callback state and returns a scalar. + using metric_array_array_state_t = result_t (*)(uptr_t, uptr_t, uptr_t); + /// Distance function callback, like `metric_array_array_size_t`, but depends on member variables. + using metric_routed_t = result_t (metric_punned_t::*)(uptr_t, uptr_t) const; + + metric_routed_t metric_routed_ = nullptr; + uptr_t metric_ptr_ = 0; + uptr_t metric_third_arg_ = 0; + + std::size_t dimensions_ = 0; + metric_kind_t metric_kind_ = metric_kind_t::unknown_k; + scalar_kind_t scalar_kind_ = scalar_kind_t::unknown_k; + +#if USEARCH_USE_SIMSIMD + simsimd_capability_t isa_kind_ = simsimd_cap_serial_k; +#endif + + public: + /** + * @brief Computes the distance between two vectors of fixed length. + * + * ! This is the only relevant function in the object. Everything else is just dynamic dispatch logic. + */ + inline result_t operator()(byte_t const* a, byte_t const* b) const noexcept { + return (this->*metric_routed_)(reinterpret_cast(a), reinterpret_cast(b)); + } + + inline metric_punned_t() noexcept = default; + inline metric_punned_t(metric_punned_t const&) noexcept = default; + inline metric_punned_t& operator=(metric_punned_t const&) noexcept = default; + + inline metric_punned_t(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept + : metric_punned_t(builtin(dimensions, metric_kind, scalar_kind)) {} + + inline metric_punned_t(std::size_t dimensions, std::uintptr_t metric_uintptr, metric_punned_signature_t signature, + metric_kind_t metric_kind, scalar_kind_t scalar_kind) noexcept + : metric_punned_t(stateless(dimensions, metric_uintptr, signature, metric_kind, scalar_kind)) {} + + /** + * @brief Creates a metric of a natively supported kind, choosing the best + * available backend internally or from SimSIMD. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t builtin(std::size_t dimensions, metric_kind_t metric_kind = metric_kind_t::l2sq_k, + scalar_kind_t scalar_kind = scalar_kind_t::f32_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = 0; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + +#if USEARCH_USE_SIMSIMD + if (!metric.configure_with_simsimd()) + metric.configure_with_autovec(); +#else + metric.configure_with_autovec(); +#endif + + return metric; + } + + /** + * @brief Creates a metric using the provided function pointer for a stateless metric. + * So the provided ::metric_uintptr is a pointer to a function that takes two arrays + * and returns a scalar. If the ::signature is metric_punned_signature_t::array_array_size_k, + * then the third argument is the number of scalar words in the input vectors. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_uintptr The function pointer to the metric function. + * @param signature The signature of the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t stateless(std::size_t dimensions, std::uintptr_t metric_uintptr, + metric_punned_signature_t signature, metric_kind_t metric_kind, + scalar_kind_t scalar_kind) noexcept { + metric_punned_t metric; + metric.metric_routed_ = signature == metric_punned_signature_t::array_array_k + ? &metric_punned_t::invoke_array_array + : &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = + scalar_kind == scalar_kind_t::b1x8_k ? divide_round_up(dimensions) : dimensions; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + /** + * @brief Creates a metric using the provided function pointer for a stateful metric. + * The third argument is the state that will be passed to the metric function. + * + * @param dimensions The number of elements in the input arrays. + * @param metric_uintptr The function pointer to the metric function. + * @param metric_state The state to pass to the metric function. + * @param metric_kind The kind of metric to use. + * @param scalar_kind The kind of scalar to use. + * @return A metric object that can be used to compute distances between vectors. + */ + inline static metric_punned_t stateful( // + std::size_t dimensions, std::uintptr_t metric_uintptr, std::uintptr_t metric_state, + metric_kind_t metric_kind = metric_kind_t::unknown_k, + scalar_kind_t scalar_kind = scalar_kind_t::unknown_k) noexcept { + metric_punned_t metric; + metric.metric_routed_ = &metric_punned_t::invoke_array_array_third; + metric.metric_ptr_ = metric_uintptr; + metric.metric_third_arg_ = metric_state; + metric.dimensions_ = dimensions; + metric.metric_kind_ = metric_kind; + metric.scalar_kind_ = scalar_kind; + return metric; + } + + inline std::size_t dimensions() const noexcept { return dimensions_; } + inline metric_kind_t metric_kind() const noexcept { return metric_kind_; } + inline scalar_kind_t scalar_kind() const noexcept { return scalar_kind_; } + inline explicit operator bool() const noexcept { return metric_routed_ && metric_ptr_; } + + /** + * @brief Checks if we've failed to initialize the metric with provided arguments. + * + * It's different from `operator bool()` when it comes to explicitly uninitialized metrics. + * It's a common case, where a NULL state is created only to be overwritten later, when + * we recover an old index state from a file or a network. + */ + inline bool missing() const noexcept { return !bool(*this) && metric_kind_ != metric_kind_t::unknown_k; } + + inline char const* isa_name() const noexcept { + if (!*this) + return "uninitialized"; + +#if USEARCH_USE_SIMSIMD + switch (isa_kind_) { + case simsimd_cap_serial_k: return "serial"; + case simsimd_cap_neon_k: return "neon"; + case simsimd_cap_neon_i8_k: return "neon_i8"; + case simsimd_cap_neon_f16_k: return "neon_f16"; + case simsimd_cap_neon_bf16_k: return "neon_bf16"; + case simsimd_cap_sve_k: return "sve"; + case simsimd_cap_sve_i8_k: return "sve_i8"; + case simsimd_cap_sve_f16_k: return "sve_f16"; + case simsimd_cap_sve_bf16_k: return "sve_bf16"; + case simsimd_cap_haswell_k: return "haswell"; + case simsimd_cap_skylake_k: return "skylake"; + case simsimd_cap_ice_k: return "ice"; + case simsimd_cap_genoa_k: return "genoa"; + case simsimd_cap_sapphire_k: return "sapphire"; + default: return "unknown"; + } +#endif + return "serial"; + } + + inline std::size_t bytes_per_vector() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_)); + } + + inline std::size_t scalar_words() const noexcept { + return divide_round_up(dimensions_ * bits_per_scalar(scalar_kind_), bits_per_scalar_word(scalar_kind_)); + } + + private: +#if USEARCH_USE_SIMSIMD + bool configure_with_simsimd(simsimd_capability_t simd_caps) noexcept { + simsimd_metric_kind_t kind = simsimd_metric_unknown_k; + simsimd_datatype_t datatype = simsimd_datatype_unknown_k; + simsimd_capability_t allowed = simsimd_cap_any_k; + switch (metric_kind_) { + case metric_kind_t::ip_k: kind = simsimd_metric_dot_k; break; + case metric_kind_t::cos_k: kind = simsimd_metric_cos_k; break; + case metric_kind_t::l2sq_k: kind = simsimd_metric_l2sq_k; break; + case metric_kind_t::hamming_k: kind = simsimd_metric_hamming_k; break; + case metric_kind_t::tanimoto_k: kind = simsimd_metric_jaccard_k; break; + case metric_kind_t::jaccard_k: kind = simsimd_metric_jaccard_k; break; + default: break; + } + switch (scalar_kind_) { + case scalar_kind_t::f32_k: datatype = simsimd_datatype_f32_k; break; + case scalar_kind_t::f64_k: datatype = simsimd_datatype_f64_k; break; + case scalar_kind_t::f16_k: datatype = simsimd_datatype_f16_k; break; + case scalar_kind_t::bf16_k: datatype = simsimd_datatype_bf16_k; break; + case scalar_kind_t::i8_k: datatype = simsimd_datatype_i8_k; break; + case scalar_kind_t::b1x8_k: datatype = simsimd_datatype_b8_k; break; + default: break; + } + simsimd_metric_dense_punned_t simd_metric = NULL; + simsimd_capability_t simd_kind = simsimd_cap_any_k; + simsimd_find_kernel_punned(kind, datatype, simd_caps, allowed, (simsimd_kernel_punned_t*)&simd_metric, + &simd_kind); + if (simd_metric == nullptr) + return false; + + std::memcpy(&metric_ptr_, &simd_metric, sizeof(simd_metric)); + metric_routed_ = metric_kind_ == metric_kind_t::ip_k + ? reinterpret_cast(&metric_punned_t::invoke_simsimd_reverse) + : reinterpret_cast(&metric_punned_t::invoke_simsimd); + isa_kind_ = simd_kind; + return true; + } + bool configure_with_simsimd() noexcept { + static simsimd_capability_t static_capabilities = simsimd_capabilities(); + return configure_with_simsimd(static_capabilities); + } + +#if defined(USEARCH_DEFINED_CLANG) || defined(USEARCH_DEFINED_GCC) + __attribute__((no_sanitize("all"))) +#endif + result_t + invoke_simsimd(uptr_t a, uptr_t b) const noexcept { + simsimd_distance_t result; + // Here `reinterpret_cast` raises warning and UBSan reports an issue... we know what we are doing! + auto function_pointer = (simsimd_metric_dense_punned_t)(metric_ptr_); + function_pointer(reinterpret_cast(a), reinterpret_cast(b), metric_third_arg_, + &result); + return (result_t)result; + } + result_t invoke_simsimd_reverse(uptr_t a, uptr_t b) const noexcept { return 1 - invoke_simsimd(a, b); } +#else + bool configure_with_simsimd() noexcept { return false; } +#endif + result_t invoke_array_array_third(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_size_t)(metric_ptr_); + result_t result = function_pointer(a, b, metric_third_arg_); + return result; + } + result_t invoke_array_array(uptr_t a, uptr_t b) const noexcept { + auto function_pointer = (metric_array_array_t)(metric_ptr_); + result_t result = function_pointer(a, b); + return result; + } + void configure_with_autovec() noexcept { + switch (metric_kind_) { + case metric_kind_t::ip_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::cos_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::l2sq_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::pearson_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::i8_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f16_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::haversine_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: metric_ptr_ = 0; break; //< Half-precision 2D vectors are silly. + case scalar_kind_t::f16_k: metric_ptr_ = 0; break; //< Half-precision 2D vectors are silly. + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::divergence_k: { + switch (scalar_kind_) { + case scalar_kind_t::bf16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::f16_k: + metric_ptr_ = (uptr_t)&equidimensional_>; + break; + case scalar_kind_t::f32_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case scalar_kind_t::f64_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: metric_ptr_ = 0; break; + } + break; + } + case metric_kind_t::jaccard_k: // Equivalent to Tanimoto + case metric_kind_t::tanimoto_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::hamming_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + case metric_kind_t::sorensen_k: metric_ptr_ = (uptr_t)&equidimensional_>; break; + default: return; + } + } + + template + inline static result_t equidimensional_(uptr_t a, uptr_t b, uptr_t a_dimensions) noexcept { + using scalar_t = typename typed_at::scalar_t; + return static_cast(typed_at{}((scalar_t const*)a, (scalar_t const*)b, a_dimensions)); + } +}; + +/* Allow complaining about vectorization after this point. */ +#if defined(USEARCH_DEFINED_CLANG) +#pragma clang diagnostic pop +#endif + +/** + * @brief View over a potentially-strided memory buffer, containing a row-major matrix. + */ +template // +class matrix_slice_gt { + using scalar_t = scalar_at; + using byte_addressable_t = typename std::conditional::value, byte_t const, byte_t>::type; + + scalar_t* begin_{}; + std::size_t dimensions_{}; + std::size_t count_{}; + std::size_t stride_bytes_{}; + + public: + matrix_slice_gt() noexcept = default; + matrix_slice_gt(matrix_slice_gt const&) noexcept = default; + matrix_slice_gt& operator=(matrix_slice_gt const&) noexcept = default; + + matrix_slice_gt(scalar_t* begin, std::size_t dimensions, std::size_t count = 1) noexcept + : matrix_slice_gt(begin, dimensions, count, dimensions * sizeof(scalar_at)) {} + + matrix_slice_gt(scalar_t* begin, std::size_t dimensions, std::size_t count, std::size_t stride_bytes) noexcept + : begin_(begin), dimensions_(dimensions), count_(count), stride_bytes_(stride_bytes) {} + + explicit operator bool() const noexcept { return begin_; } + std::size_t size() const noexcept { return count_; } + std::size_t dimensions() const noexcept { return dimensions_; } + std::size_t stride_bytes() const noexcept { return stride_bytes_; } + scalar_t* data() const noexcept { return begin_; } + scalar_t* at(std::size_t i) const noexcept { + return reinterpret_cast(reinterpret_cast(begin_) + i * stride_bytes_); + } +}; + +struct exact_offset_and_distance_t { + u32_t offset; + f32_t distance; +}; + +using exact_search_results_t = matrix_slice_gt; + +/** + * @brief Helper-structure for exact search operations. + * Perfect if you have @b <1M vectors and @b <100 queries per call. + * + * Uses a 3-step procedure to minimize: + * - cache-misses on vector lookups, + * - multi-threaded contention on concurrent writes. + */ +class exact_search_t { + + inline static bool smaller_distance(exact_offset_and_distance_t a, exact_offset_and_distance_t b) noexcept { + return a.distance < b.distance; + } + + using keys_and_distances_t = buffer_gt; + keys_and_distances_t keys_and_distances; + + public: + template + exact_search_results_t operator()( // + matrix_slice_gt dataset, matrix_slice_gt queries, // + std::size_t wanted, metric_punned_t const& metric, // + executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + return operator()( // + metric, // + reinterpret_cast(dataset.data()), dataset.size(), dataset.stride_bytes(), // + reinterpret_cast(queries.data()), queries.size(), queries.stride_bytes(), // + wanted, executor, progress); + } + + template + exact_search_results_t operator()( // + byte_t const* dataset_data, std::size_t dataset_count, std::size_t dataset_stride, // + byte_t const* queries_data, std::size_t queries_count, std::size_t queries_stride, // + std::size_t wanted, metric_punned_t const& metric, executor_at&& executor = executor_at{}, + progress_at&& progress = progress_at{}) { + + // Allocate temporary memory to store the distance matrix + // Previous version didn't need temporary memory, but the performance was much lower. + // In the new design we keep two buffers - original and transposed, as in-place transpositions + // of non-rectangular matrixes is expensive. + std::size_t tasks_count = dataset_count * queries_count; + if (keys_and_distances.size() < tasks_count * 2) + keys_and_distances = keys_and_distances_t(tasks_count * 2); + if (keys_and_distances.size() < tasks_count * 2) + return {}; + + exact_offset_and_distance_t* keys_and_distances_per_dataset = keys_and_distances.data(); + exact_offset_and_distance_t* keys_and_distances_per_query = keys_and_distances_per_dataset + tasks_count; + + // §1. Compute distances in a data-parallel fashion + std::atomic processed{0}; + executor.dynamic(dataset_count, [&](std::size_t thread_idx, std::size_t dataset_idx) { + byte_t const* dataset = dataset_data + dataset_idx * dataset_stride; + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + byte_t const* query = queries_data + query_idx * queries_stride; + auto distance = metric(dataset, query); + std::size_t task_idx = queries_count * dataset_idx + query_idx; + keys_and_distances_per_dataset[task_idx].offset = static_cast(dataset_idx); + keys_and_distances_per_dataset[task_idx].distance = static_cast(distance); + } + + // It's more efficient in this case to report progress from a single thread + processed += queries_count; + if (thread_idx == 0) + if (!progress(processed.load(), tasks_count)) + return false; + return true; + }); + if (processed.load() != tasks_count) + return {}; + + // §2. Transpose in a single thread to avoid contention writing into the same memory buffers + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + for (std::size_t dataset_idx = 0; dataset_idx != dataset_count; ++dataset_idx) { + std::size_t from_idx = queries_count * dataset_idx + query_idx; + std::size_t to_idx = dataset_count * query_idx + dataset_idx; + keys_and_distances_per_query[to_idx] = keys_and_distances_per_dataset[from_idx]; + } + } + + // §3. Partial-sort every query result + executor.fixed(queries_count, [&](std::size_t, std::size_t query_idx) { + auto start = keys_and_distances_per_query + dataset_count * query_idx; + if (wanted > 1) { + // TODO: Consider alternative sorting approaches + // radix_sort(start, start + dataset_count, wanted); + // std::sort(start, start + dataset_count, &smaller_distance); + std::partial_sort(start, start + wanted, start + dataset_count, &smaller_distance); + } else { + auto min_it = std::min_element(start, start + dataset_count, &smaller_distance); + if (min_it != start) + std::swap(*min_it, *start); + } + }); + + // At the end report the latest numbers, because the reporter thread may be finished earlier + progress(tasks_count, tasks_count); + return {keys_and_distances_per_query, wanted, queries_count, + dataset_count * sizeof(exact_offset_and_distance_t)}; + } +}; + +struct kmeans_clustering_result_t { + error_t error{}; + std::size_t computed_distances{}; + /// @brief The number of iterations the algorithm took to converge. + std::size_t iterations{}; + /// @brief The number of points that changed clusters in the last iteration. + std::size_t last_iteration_points_shifted{}; + /// @brief The inertia of the last iteration (sum of squared distances to centroids). + f64_t last_iteration_inertia{}; + /// @brief The total elapsed runtime of the algorithm in seconds. + f64_t runtime_seconds{}; + /// @brief The total distance between the points and their assigned centroids. + f64_t aggregate_distance{}; + + explicit operator bool() const noexcept { return !error; } + kmeans_clustering_result_t failed(error_t message) noexcept { + error = std::move(message); + return std::move(*this); + } +}; + +/** + * @brief Helper-class for K-Means clustering of dense vectors. + * Doesn't require constructing the index, but benefits from mixed-precision logic. + * ! Doesn't guarantee that the clusters are balanced in size. + * + * The algorithm is as follows: + * - Initialization: Select K initial centroids (randomly or with a heuristic). + * - Assignment: Assign each data point to the nearest centroid based on the Euclidean distance. + * - Update: Recalculate the centroids as the mean of all points assigned to each centroid. + * - Repeat: Repeat the assignment and update steps until the centroids no longer change significantly + * or an early-exit condition is met. + */ +template > class kmeans_clustering_gt { + public: + using distance_t = distance_punned_t; + + metric_kind_t metric_kind{metric_kind_t::l2sq_k}; + scalar_kind_t quantization_kind{scalar_kind_t::bf16_k}; + + static constexpr std::size_t max_iterations_default_k = 300; + static constexpr f64_t inertia_threshold_default_k = 1e-4; + static constexpr f64_t max_seconds_default_k = 60.0; + static constexpr f64_t min_shifts_default_k = 0.01; + + /// @brief Early-exit parameter - the maximum number of iterations to perform. + std::size_t max_iterations{max_iterations_default_k}; + /// @brief Early-exit parameter - the threshold for the final inertia to terminate early. + f64_t inertia_threshold{inertia_threshold_default_k}; + /// @brief Early-exit parameter - the maximum runtime allowed in seconds. + f64_t max_seconds{max_seconds_default_k}; + /// @brief Early-exit parameter - the minimum share of points that must change clusters per iteration. + f64_t min_shifts{min_shifts_default_k}; + /// @brief The random seed to use for centroid initialization. + std::uint64_t seed{0}; + + kmeans_clustering_gt(std::uint64_t seed) noexcept : seed(seed) {} + kmeans_clustering_gt() noexcept(false) { + std::random_device random_device; + seed = random_device(); + } + + kmeans_clustering_gt(kmeans_clustering_gt const&) = default; + kmeans_clustering_gt& operator=(kmeans_clustering_gt const&) = default; + + template + kmeans_clustering_result_t operator()( // + matrix_slice_gt points, matrix_slice_gt centroids, + span_gt point_to_centroid_index, span_gt point_to_centroid_distance, // + executor_at&& executor = executor_at{}, progress_at&& progress = progress_at{}) { + return operator()( // + reinterpret_cast(points.data()), points.size(), points.stride_bytes(), // + reinterpret_cast(centroids.data()), centroids.size(), centroids.stride_bytes(), + point_to_centroid_index.data(), point_to_centroid_distance.data(), // + scalar_kind(), points.dimensions(), executor, progress); + } + + template + kmeans_clustering_result_t operator()( // + byte_t const* points_data, std::size_t points_count, std::size_t points_stride_bytes, // + byte_t* centroids_data, std::size_t wanted_clusters, std::size_t centroids_stride_bytes, // + std::size_t* point_to_centroid_index, distance_t* point_to_centroid_distance, // + scalar_kind_t original_scalar_kind, std::size_t dimensions, executor_at&& executor = executor_at{}, + progress_at&& progress = progress_at{}) { + + (void)progress; // TODO + + // Perform sanity checks for algorithm settings. + kmeans_clustering_result_t result; + if (max_iterations < 1) + return result.failed("The number of iterations must be at least 1"); + + // Perform sanity checks for input arguments. + if (wanted_clusters < 2) + return result.failed("The number of clusters must be at least 2"); + if (wanted_clusters >= points_count) + return result.failed("The number of clusters must be less than the number of vectors"); + + metric_punned_t metric = metric_punned_t::builtin(dimensions, metric_kind, quantization_kind); + if (!metric) + return result.failed("Unsupported metric or scalar kind"); + + // Let's allocate memory for the centroids coordinates and make sure it's + // rows are aligned to cache lines to avoid false sharing. + buffer_gt> point_to_centroid_distance_buffer(points_count); + buffer_gt> point_to_centroid_index_buffer(points_count); + buffer_gt, aligned_allocator_gt, 64>> cluster_sizes_buffer( + wanted_clusters); + + // For a mixed precision computation, we keep the centroids represented in two forms - + // double precision and quantized the same way as in the index, to avoid paying conversion penalties. + // Double precision is needed to avoid accumulating errors when aggregating too many entries. + std::size_t const bytes_per_vector_original = + divide_round_up(dimensions * bits_per_scalar(original_scalar_kind)); + std::size_t const bytes_per_vector_quantized = metric.bytes_per_vector(); + std::size_t const stride_per_vector_quantized = divide_round_up<64>(bytes_per_vector_quantized) * 64; + buffer_gt> points_quantized_buffer( // + points_count * stride_per_vector_quantized); + buffer_gt> centroids_quantized_buffer( // + wanted_clusters * stride_per_vector_quantized); + + // When aggregating centroids, we want to parallelize the operation and need more memory. + // For every thread we keep two double-precision vectors. One is the up-casting output buffer for quantized + // coordinates, and the other is the temporary buffer for the partial sums of the double-precision coordinates. + // The ordering: + // + // - thread 0: [centroid 0, centroid 1, centroid 2, centroid 3, ...] + // - thread 1: [centroid 0, centroid 1, centroid 2, centroid 3, ...] + // - thread 2: [centroid 0, centroid 1, centroid 2, centroid 3, ...] + // + std::size_t const thread_count = executor.size(); + buffer_gt> centroids_precise_buffer( // + wanted_clusters * dimensions * thread_count); + buffer_gt> points_precise_buffer( // + wanted_clusters * dimensions * thread_count); + + // Check if all memory allocations were successful. + if (!centroids_precise_buffer || !points_precise_buffer || !centroids_quantized_buffer || + !point_to_centroid_index_buffer || !cluster_sizes_buffer || !point_to_centroid_distance_buffer || + !points_quantized_buffer) + return result.failed("No memory for result outputs!"); + + std::fill_n(point_to_centroid_index_buffer.data(), points_count, wanted_clusters); + std::fill_n(point_to_centroid_distance_buffer.data(), points_count, std::numeric_limits::max()); + + // Initialize the casting kernel for quantization and export. + casts_punned_t casts = casts_punned_t::make(quantization_kind); + cast_punned_t const& compress_points = casts.from[original_scalar_kind]; + cast_punned_t const& decompress_points = casts.to[original_scalar_kind]; + cast_punned_t const& compress_precise = casts.from.f64; + cast_punned_t const& decompress_precise = casts.to.f64; + for (std::size_t i = 0; i < points_count; i++) { + byte_t const* vector = points_data + i * points_stride_bytes; + byte_t* quantized = points_quantized_buffer.data() + i * stride_per_vector_quantized; + if (!compress_points(vector, dimensions, quantized)) + std::memcpy(quantized, vector, bytes_per_vector_original); + } + + // Initialize centroids with random points vectors. + std::mt19937_64 random_engine; + random_engine.seed(seed); + for (std::size_t i = 0; i < wanted_clusters; i++) { + // Generate the random index of the points vector, + // that is unique and not already used as a centroid. + std::size_t random_index; + do { + random_index = random_engine() % points_count; + bool is_unique = true; + for (std::size_t j = 0; j < i; j++) { + if (point_to_centroid_index_buffer[j] == random_index) { + is_unique = false; + break; + } + } + if (is_unique) + break; + } while (true); + + // Copy the vector to the centroid and quantize it. + byte_t const* quantized_point = points_quantized_buffer.data() + random_index * stride_per_vector_quantized; + byte_t* quantized_centroid = centroids_quantized_buffer.data() + i * stride_per_vector_quantized; + std::memcpy(quantized_centroid, quantized_point, bytes_per_vector_quantized); + point_to_centroid_index_buffer[random_index] = i; + point_to_centroid_distance_buffer[random_index] = 0; + } + + auto start_time = std::chrono::high_resolution_clock::now(); + std::size_t iterations = 0; + std::size_t const min_points_shifted_per_iteration = static_cast(min_shifts * points_count); + f64_t last_aggregate_distance = std::numeric_limits::max(); + + while (iterations < max_iterations) { + iterations++; + + // For every point, find the closest centroid. + std::atomic points_shifted{0}; + executor.dynamic(points_count, [&](std::size_t, std::size_t points_idx) { + byte_t const* quantized_point = + points_quantized_buffer.data() + points_idx * stride_per_vector_quantized; + byte_t const* quantized_centroids = centroids_quantized_buffer.data(); + distance_t closest_distance_local = std::numeric_limits::max(); + std::size_t closest_idx_local = 0; + for (std::size_t centroid_idx = 0; centroid_idx < wanted_clusters; centroid_idx++) { + byte_t const* quantized_centroid = quantized_centroids + centroid_idx * stride_per_vector_quantized; + distance_t distance = metric(quantized_point, quantized_centroid); + if (distance < closest_distance_local) { + closest_distance_local = distance; + closest_idx_local = centroid_idx; + } + } + + distance_t& closest_distance_ref = point_to_centroid_distance_buffer[points_idx]; + std::size_t& closest_idx_ref = point_to_centroid_index_buffer[points_idx]; + if (closest_idx_local != closest_idx_ref) { + closest_idx_ref = closest_idx_local; + points_shifted.fetch_add(1, std::memory_order_relaxed); + } + + closest_distance_ref = closest_distance_local; + return true; + }); + + f64_t aggregate_distance = 0.0; + for (std::size_t i = 0; i < points_count; i++) + aggregate_distance += point_to_centroid_distance_buffer[i]; + f64_t aggregate_distance_change = + std::abs(aggregate_distance - last_aggregate_distance) / last_aggregate_distance; + + auto current_time = std::chrono::high_resolution_clock::now(); + std::chrono::duration elapsed_time = current_time - start_time; + result.runtime_seconds = elapsed_time.count(); + result.last_iteration_inertia = aggregate_distance_change; + result.last_iteration_points_shifted = points_shifted.load(std::memory_order_relaxed); + + // Check for early-exit conditions + if (last_aggregate_distance != 0.0 && inertia_threshold != 0.0) + if (aggregate_distance_change <= inertia_threshold) + break; + if (min_points_shifted_per_iteration != 0 || result.last_iteration_points_shifted == 0) + if (result.last_iteration_points_shifted <= min_points_shifted_per_iteration) + break; + if (max_seconds != 0) + if (result.runtime_seconds >= max_seconds) + break; + + // For every centroid, recalculate the mean of all points assigned to it. + // That part is problematic to parallelize on many-core-systems, because of the contention. + // Alternatively, a tree-like approach can be used, where every core accumulates it's own partial sums. + // And those are later aggregated by a single thread. + std::memset(centroids_precise_buffer.data(), 0, + wanted_clusters * dimensions * thread_count * sizeof(f64_t)); + std::memset(reinterpret_cast(cluster_sizes_buffer.data()), 0, + wanted_clusters * sizeof(std::atomic)); + executor.dynamic(points_count, [&](std::size_t thread_idx, std::size_t points_idx) { + std::size_t centroid_idx = point_to_centroid_index_buffer[points_idx]; + byte_t const* quantized_point = + points_quantized_buffer.data() + points_idx * stride_per_vector_quantized; + f64_t* centroid_precise = centroids_precise_buffer.data() + wanted_clusters * dimensions * thread_idx + + centroid_idx * dimensions; + + // Upcast the points point into a buffer of double-precision floats. + f64_t* point_precise = points_precise_buffer.data() + wanted_clusters * dimensions * thread_idx + + centroid_idx * dimensions; + if (!decompress_precise(quantized_point, dimensions, reinterpret_cast(point_precise))) + std::memcpy(reinterpret_cast(point_precise), quantized_point, bytes_per_vector_quantized); + + // Now add the vector from the points into the centroid partial sum. + for (std::size_t i = 0; i < dimensions; i++) + centroid_precise[i] += point_precise[i]; + + cluster_sizes_buffer[centroid_idx].fetch_add(1, std::memory_order_relaxed); + return true; + }); + + // Aggregate the partial sums into the final centroids - storing them in the high-precision + // buffer of the first thread. Normalization procedure is different for different metrics. + for (std::size_t centroid_idx = 0; centroid_idx < wanted_clusters; centroid_idx++) { + f64_t* centroid_precise_aggregated = centroids_precise_buffer.data() + centroid_idx * dimensions; + for (std::size_t thread_idx = 1; thread_idx < thread_count; thread_idx++) { + f64_t* centroid_precise = centroids_precise_buffer.data() + + wanted_clusters * dimensions * thread_idx + centroid_idx * dimensions; + for (std::size_t i = 0; i < dimensions; i++) + centroid_precise_aggregated[i] += centroid_precise[i]; + } + + // Normalize based on the metric kind + if (metric_kind == metric_kind_t::l2sq_k) { + // Normalize for Euclidean distance (L2) + std::size_t cluster_size = cluster_sizes_buffer[centroid_idx].load(std::memory_order_relaxed); + if (cluster_size > 0) + for (std::size_t i = 0; i < dimensions; i++) + centroid_precise_aggregated[i] /= static_cast(cluster_size); + + } else if (metric_kind == metric_kind_t::cos_k) { + // Normalize for Cosine distance + f64_t norm = 0.0; + for (std::size_t i = 0; i < dimensions; i++) + norm += centroid_precise_aggregated[i] * centroid_precise_aggregated[i]; + norm = std::sqrt(norm); + if (norm > 0.0) + for (std::size_t i = 0; i < dimensions; i++) + centroid_precise_aggregated[i] /= norm; + } + + // Quantize the centroid after normalization for further iterations + byte_t* centroid_quantized = + centroids_quantized_buffer.data() + centroid_idx * stride_per_vector_quantized; + if (!compress_precise(reinterpret_cast(centroid_precise_aggregated), dimensions, + centroid_quantized)) + std::memcpy(centroid_quantized, reinterpret_cast(centroid_precise_aggregated), + bytes_per_vector_quantized); + } + } + + // Export stats. + result.iterations = iterations; + result.computed_distances = points_count * wanted_clusters * iterations; + result.aggregate_distance = 0; + for (distance_t distance : point_to_centroid_distance_buffer) + result.aggregate_distance += distance; + + // We've finished all the iterations, now we can export the centroids back to the original precision. + std::memcpy(point_to_centroid_index, point_to_centroid_index_buffer.data(), points_count * sizeof(std::size_t)); + std::memcpy(point_to_centroid_distance, point_to_centroid_distance_buffer.data(), + points_count * sizeof(distance_t)); + for (std::size_t i = 0; i < wanted_clusters; i++) { + byte_t const* quantized_centroid = centroids_quantized_buffer.data() + i * stride_per_vector_quantized; + byte_t* centroid = centroids_data + i * centroids_stride_bytes; + if (!decompress_points(quantized_centroid, dimensions, centroid)) + std::memcpy(centroid, quantized_centroid, bytes_per_vector_quantized); + } + + return result; + } +}; + +using kmeans_clustering_t = kmeans_clustering_gt<>; + +/** + * @brief C++11 Multi-Hash-Set with Linear Probing. + * + * - Allows multiple equivalent values, + * - Supports transparent hashing and equality operator. + * - Doesn't throw exceptions, if forbidden. + * - Doesn't need reserving a value for deletions. + * + * @section Layout + * + * For every slot we store 2 extra bits for 3 possible states: empty, populated, or deleted. + * With linear probing the hashes at the end of the populated region will spill into its first half. + */ +template > +class flat_hash_multi_set_gt { + public: + using element_t = element_at; + using hash_t = hash_at; + using equals_t = equals_at; + using allocator_t = allocator_at; + + static constexpr std::size_t slots_per_bucket() { return 64; } + static constexpr std::size_t bytes_per_bucket() { + return slots_per_bucket() * sizeof(element_t) + sizeof(bucket_header_t); + } + + private: + struct bucket_header_t { + std::uint64_t populated{}; + std::uint64_t deleted{}; + }; + char* data_ = nullptr; + std::size_t buckets_ = 0; + std::size_t populated_slots_ = 0; + /// @brief Number of slots + std::size_t capacity_slots_ = 0; + + struct slot_ref_t { + bucket_header_t& header; + std::uint64_t mask; + element_t& element; + }; + + slot_ref_t slot_ref(char* data, std::size_t slot_index) const noexcept { + std::size_t bucket_index = slot_index / slots_per_bucket(); + std::size_t in_bucket_index = slot_index % slots_per_bucket(); + auto bucket_pointer = data + bytes_per_bucket() * bucket_index; + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + return { + *reinterpret_cast(bucket_pointer), + static_cast(1ull) << in_bucket_index, + *reinterpret_cast(slot_pointer), + }; + } + + slot_ref_t slot_ref(std::size_t slot_index) const noexcept { return slot_ref(data_, slot_index); } + + bool populate_slot(slot_ref_t slot, element_t const& new_element) { + if (slot.header.populated & slot.mask) { + slot.element = new_element; + slot.header.deleted &= ~slot.mask; + return false; + } else { + new (&slot.element) element_t(new_element); + slot.header.populated |= slot.mask; + return true; + } + } + + public: + std::size_t size() const noexcept { return populated_slots_; } + std::size_t capacity() const noexcept { return capacity_slots_; } + + flat_hash_multi_set_gt() noexcept {} + ~flat_hash_multi_set_gt() noexcept { reset(); } + + flat_hash_multi_set_gt(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return; + } + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + usearch_raise_runtime_error("failed memory allocation"); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + } + + flat_hash_multi_set_gt& operator=(flat_hash_multi_set_gt const& other) { + + // On Windows allocating a zero-size array would fail + if (!other.buckets_) { + reset(); + return *this; + } + + // Handle self-assignment + if (this == &other) + return *this; + + // Clear existing data + clear(); + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + + // Allocate new memory + data_ = (char*)allocator_t{}.allocate(other.buckets_ * bytes_per_bucket()); + if (!data_) + usearch_raise_runtime_error("failed memory allocation"); + + // Copy metadata + buckets_ = other.buckets_; + populated_slots_ = other.populated_slots_; + capacity_slots_ = other.capacity_slots_; + + // Initialize new buckets to empty + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + + // Copy elements and bucket headers + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = other.slot_ref(i); + if ((old_slot.header.populated & old_slot.mask) && !(old_slot.header.deleted & old_slot.mask)) { + slot_ref_t new_slot = slot_ref(i); + populate_slot(new_slot, old_slot.element); + } + } + + return *this; + } + + void clear() noexcept { + // Call the destructors + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t slot = slot_ref(i); + if ((slot.header.populated & slot.mask) & (~slot.header.deleted & slot.mask)) + slot.element.~element_t(); + } + + // Reset populated slots count + if (data_) + std::memset(data_, 0, buckets_ * bytes_per_bucket()); + populated_slots_ = 0; + } + + void reset() noexcept { + clear(); // Clear all elements + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + data_ = nullptr; + buckets_ = 0; + populated_slots_ = 0; + capacity_slots_ = 0; + } + + bool try_reserve(std::size_t capacity) noexcept { + if (capacity * 3u <= capacity_slots_ * 2u) + return true; + + // Calculate new sizes + std::size_t new_slots = ceil2((capacity * 3ul) / 2ul); + std::size_t new_buckets = divide_round_up(new_slots); + new_slots = new_buckets * slots_per_bucket(); // This must be a power of two! + std::size_t new_bytes = new_buckets * bytes_per_bucket(); + + // Allocate new memory + char* new_data = (char*)allocator_t{}.allocate(new_bytes); + if (!new_data) + return false; + + // Initialize new buckets to empty + std::memset(new_data, 0, new_bytes); + + // Rehash and copy existing elements to new_data + hash_t hasher; + for (std::size_t i = 0; i < capacity_slots_; ++i) { + slot_ref_t old_slot = slot_ref(i); + if ((~old_slot.header.populated & old_slot.mask) | (old_slot.header.deleted & old_slot.mask)) + continue; + + // Rehash + std::size_t hash_value = hasher(old_slot.element); + std::size_t new_slot_index = hash_value & (new_slots - 1); + + // Linear probing to find an empty slot in new_data + while (true) { + slot_ref_t new_slot = slot_ref(new_data, new_slot_index); + if (!(new_slot.header.populated & new_slot.mask) || (new_slot.header.deleted & new_slot.mask)) { + populate_slot(new_slot, std::move(old_slot.element)); + new_slot.header.populated |= new_slot.mask; + break; + } + new_slot_index = (new_slot_index + 1) & (new_slots - 1); + } + } + + // Deallocate old data and update pointers and sizes + if (data_) + allocator_t{}.deallocate(data_, buckets_ * bytes_per_bucket()); + data_ = new_data; + buckets_ = new_buckets; + capacity_slots_ = new_slots; + + return true; + } + + template class equal_iterator_gt { + public: + using iterator_category = std::forward_iterator_tag; + using value_type = element_t; + using difference_type = std::ptrdiff_t; + using pointer = element_t*; + using reference = element_t&; + + equal_iterator_gt(std::size_t index, flat_hash_multi_set_gt* parent, query_at const& query, + equals_t const& equals) + : index_(index), parent_(parent), query_(query), equals_(equals) {} + + // Pre-increment + equal_iterator_gt& operator++() { + do { + index_ = (index_ + 1) & (parent_->capacity_slots_ - 1); + } while (!equals_(parent_->slot_ref(index_).element, query_) && + (parent_->slot_ref(index_).header.populated & parent_->slot_ref(index_).mask)); + return *this; + } + + equal_iterator_gt operator++(int) { + equal_iterator_gt temp = *this; + ++(*this); + return temp; + } + + reference operator*() { return parent_->slot_ref(index_).element; } + pointer operator->() { return &parent_->slot_ref(index_).element; } + bool operator!=(equal_iterator_gt const& other) const { return !(*this == other); } + bool operator==(equal_iterator_gt const& other) const { + return index_ == other.index_ && parent_ == other.parent_; + } + + private: + std::size_t index_; + flat_hash_multi_set_gt* parent_; + query_at query_; // Store the query object + equals_t equals_; // Store the equals functor + }; + + /** + * @brief Returns an iterator range of all elements matching the given query. + * + * Technically, the second iterator points to the first empty slot after a + * range of equal values and non-equal values with similar hashes. + */ + template + std::pair, equal_iterator_gt> + equal_range(query_at const& query) const noexcept { + + equals_t equals; + auto this_ptr = const_cast(this); + auto end = equal_iterator_gt(capacity_slots_, this_ptr, query, equals); + if (!capacity_slots_) + return {end, end}; + + hash_t hasher; + std::size_t hash_value = hasher(query); + std::size_t first_equal_index = hash_value & (capacity_slots_ - 1); + std::size_t const start_index = first_equal_index; + + // Linear probing to find the first equal element + do { + slot_ref_t slot = slot_ref(first_equal_index); + if (slot.header.populated & ~slot.header.deleted & slot.mask) { + if (equals(slot.element, query)) + break; + } + // Stop if we find an empty slot + else if (~slot.header.populated & slot.mask) + return {end, end}; + + // Move to the next slot + first_equal_index = (first_equal_index + 1) & (capacity_slots_ - 1); + } while (first_equal_index != start_index); + + // If no matching element was found, return end iterators + if (first_equal_index == capacity_slots_) + return {end, end}; + + // Start from the first matching element and find the end of the populated range + std::size_t first_empty_index = first_equal_index; + do { + first_empty_index = (first_empty_index + 1) & (capacity_slots_ - 1); + slot_ref_t slot = slot_ref(first_empty_index); + + // If we find an empty slot, this is our end + if (~slot.header.populated & slot.mask) + break; + } while (first_empty_index != start_index); + + return {equal_iterator_gt(first_equal_index, this_ptr, query, equals), + equal_iterator_gt(first_empty_index, this_ptr, query, equals)}; + } + + template bool pop_first(similar_at&& query, element_t& popped_value) noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + popped_value = slot.element; + return true; // Successfully removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return false; // No match found + } + + template std::size_t erase(similar_at&& query) noexcept { + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t const start_index = slot_index; // To detect loop in probing + std::size_t count = 0; // Count of elements removed + + // Linear probing to find all matches + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) { + // Found a match, mark as deleted + slot.header.deleted |= slot.mask; + --populated_slots_; + ++count; // Increment count of elements removed + } + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return count; // Return the number of elements removed + } + + template element_t const* find(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return nullptr; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return &slot.element; // Found a match, return pointer to the element + } else { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); // Assuming capacity_slots_ is a power of 2 + } while (slot_index != start_index); + + return nullptr; // No match found + } + + element_t const* end() const noexcept { return nullptr; } + + template void for_each(func_at&& func) const { + for (std::size_t bucket_index = 0; bucket_index < buckets_; ++bucket_index) { + auto bucket_pointer = data_ + bytes_per_bucket() * bucket_index; + bucket_header_t& header = *reinterpret_cast(bucket_pointer); + std::uint64_t populated = header.populated; + std::uint64_t deleted = header.deleted; + + // Iterate through slots in the bucket + for (std::size_t in_bucket_index = 0; in_bucket_index < slots_per_bucket(); ++in_bucket_index) { + std::uint64_t mask = std::uint64_t(1ull) << in_bucket_index; + + // Check if the slot is populated and not deleted + if ((populated & ~deleted) & mask) { + auto slot_pointer = bucket_pointer + sizeof(bucket_header_t) + sizeof(element_t) * in_bucket_index; + element_t const& element = *reinterpret_cast(slot_pointer); + func(element); + } + } + } + } + + template std::size_t count(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return 0; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + std::size_t count = 0; + + // Linear probing to find the range + do { + slot_ref_t slot = slot_ref(slot_index); + if ((slot.header.populated & slot.mask) && (~slot.header.deleted & slot.mask)) { + if (equals(slot.element, query)) + ++count; + } else if (~slot.header.populated & slot.mask) { + // Stop if we find an empty slot + break; + } + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return count; + } + + template bool contains(similar_at&& query) const noexcept { + + if (!capacity_slots_) + return false; + + hash_t hasher; + equals_t equals; + std::size_t hash_value = hasher(query); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + std::size_t start_index = slot_index; // To detect loop in probing + + // Linear probing to find the first match + do { + slot_ref_t slot = slot_ref(slot_index); + if (slot.header.populated & slot.mask) { + if ((~slot.header.deleted & slot.mask) && equals(slot.element, query)) + return true; // Found a match, exit early + } else + // Stop if we find an empty slot + break; + + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } while (slot_index != start_index); + + return false; // No match found + } + + void reserve(std::size_t capacity) { + if (!try_reserve(capacity)) + usearch_raise_runtime_error("failed to reserve memory"); + } + + bool try_emplace(element_t const& element) noexcept { + // Check if we need to resize + if (populated_slots_ * 3u >= capacity_slots_ * 2u) + if (!try_reserve(populated_slots_ + 1)) + return false; + + hash_t hasher; + std::size_t hash_value = hasher(element); + std::size_t slot_index = hash_value & (capacity_slots_ - 1); + + // Linear probing + while (true) { + slot_ref_t slot = slot_ref(slot_index); + if ((~slot.header.populated & slot.mask) | (slot.header.deleted & slot.mask)) { + // Found an empty or deleted slot + populate_slot(slot, element); + ++populated_slots_; + return true; + } + // Move to the next slot + slot_index = (slot_index + 1) & (capacity_slots_ - 1); + } + } +}; + +} // namespace usearch +} // namespace unum diff --git a/zig/usearch/include/lib.cpp b/zig/usearch/include/lib.cpp new file mode 100644 index 000000000..ff447263a --- /dev/null +++ b/zig/usearch/include/lib.cpp @@ -0,0 +1,507 @@ +#include + +#include "index_dense.hpp" + +extern "C" { +#include "usearch.h" +} + +// Check if NDEBUG is defined to determine if it's a release build +#ifdef NDEBUG +#define USEARCH_ASSERT(expression) (void)(expression) +#else +#define USEARCH_ASSERT(expression) assert(expression) +#endif + +using namespace unum::usearch; +using namespace unum; + +using add_result_t = typename index_dense_t::add_result_t; +using search_result_t = typename index_dense_t::search_result_t; +using labeling_result_t = typename index_dense_t::labeling_result_t; + +static_assert(std::is_same::value, "Type mismatch between C and C++"); +static_assert(std::is_same::value, "Type mismatch between C and C++"); + +metric_kind_t metric_kind_to_cpp(usearch_metric_kind_t kind) { + switch (kind) { + case usearch_metric_ip_k: return metric_kind_t::ip_k; + case usearch_metric_l2sq_k: return metric_kind_t::l2sq_k; + case usearch_metric_cos_k: return metric_kind_t::cos_k; + case usearch_metric_haversine_k: return metric_kind_t::haversine_k; + case usearch_metric_divergence_k: return metric_kind_t::divergence_k; + case usearch_metric_pearson_k: return metric_kind_t::pearson_k; + case usearch_metric_jaccard_k: return metric_kind_t::jaccard_k; + case usearch_metric_hamming_k: return metric_kind_t::hamming_k; + case usearch_metric_tanimoto_k: return metric_kind_t::tanimoto_k; + case usearch_metric_sorensen_k: return metric_kind_t::sorensen_k; + default: return metric_kind_t::unknown_k; + } +} + +usearch_metric_kind_t metric_kind_to_c(metric_kind_t kind) { + switch (kind) { + case metric_kind_t::ip_k: return usearch_metric_ip_k; + case metric_kind_t::l2sq_k: return usearch_metric_l2sq_k; + case metric_kind_t::cos_k: return usearch_metric_cos_k; + case metric_kind_t::haversine_k: return usearch_metric_haversine_k; + case metric_kind_t::divergence_k: return usearch_metric_divergence_k; + case metric_kind_t::pearson_k: return usearch_metric_pearson_k; + case metric_kind_t::jaccard_k: return usearch_metric_jaccard_k; + case metric_kind_t::hamming_k: return usearch_metric_hamming_k; + case metric_kind_t::tanimoto_k: return usearch_metric_tanimoto_k; + case metric_kind_t::sorensen_k: return usearch_metric_sorensen_k; + default: return usearch_metric_unknown_k; + } +} +scalar_kind_t scalar_kind_to_cpp(usearch_scalar_kind_t kind) { + switch (kind) { + case usearch_scalar_f32_k: return scalar_kind_t::f32_k; + case usearch_scalar_f64_k: return scalar_kind_t::f64_k; + case usearch_scalar_f16_k: return scalar_kind_t::f16_k; + case usearch_scalar_bf16_k: return scalar_kind_t::bf16_k; + case usearch_scalar_i8_k: return scalar_kind_t::i8_k; + case usearch_scalar_b1_k: return scalar_kind_t::b1x8_k; + default: return scalar_kind_t::unknown_k; + } +} + +usearch_scalar_kind_t scalar_kind_to_c(scalar_kind_t kind) { + switch (kind) { + case scalar_kind_t::f32_k: return usearch_scalar_f32_k; + case scalar_kind_t::f64_k: return usearch_scalar_f64_k; + case scalar_kind_t::f16_k: return usearch_scalar_f16_k; + case scalar_kind_t::bf16_k: return usearch_scalar_bf16_k; + case scalar_kind_t::i8_k: return usearch_scalar_i8_k; + case scalar_kind_t::b1x8_k: return usearch_scalar_b1_k; + default: return usearch_scalar_unknown_k; + } +} + +add_result_t add_(index_dense_t* index, usearch_key_t key, void const* vector, scalar_kind_t kind) { + switch (kind) { + case scalar_kind_t::f32_k: return index->add(key, (f32_t const*)vector); + case scalar_kind_t::f64_k: return index->add(key, (f64_t const*)vector); + case scalar_kind_t::f16_k: return index->add(key, (f16_t const*)vector); + case scalar_kind_t::bf16_k: return index->add(key, (bf16_t const*)vector); + case scalar_kind_t::i8_k: return index->add(key, (i8_t const*)vector); + case scalar_kind_t::b1x8_k: return index->add(key, (b1x8_t const*)vector); + default: return add_result_t{}.failed("Unknown scalar kind!"); + } +} + +std::size_t get_(index_dense_t* index, usearch_key_t key, size_t count, void* vector, scalar_kind_t kind) { + switch (kind) { + case scalar_kind_t::f32_k: return index->get(key, (f32_t*)vector, count); + case scalar_kind_t::f64_k: return index->get(key, (f64_t*)vector, count); + case scalar_kind_t::f16_k: return index->get(key, (f16_t*)vector, count); + case scalar_kind_t::bf16_k: return index->get(key, (bf16_t*)vector, count); + case scalar_kind_t::i8_k: return index->get(key, (i8_t*)vector, count); + case scalar_kind_t::b1x8_k: return index->get(key, (b1x8_t*)vector, count); + default: return search_result_t(*index).failed("Unknown scalar kind!"); + } +} + +template +search_result_t search_(index_dense_t* index, void const* vector, scalar_kind_t kind, size_t n, + predicate_at&& predicate = predicate_at{}) { + switch (kind) { + case scalar_kind_t::f32_k: + return index->filtered_search((f32_t const*)vector, n, std::forward(predicate)); + case scalar_kind_t::f64_k: + return index->filtered_search((f64_t const*)vector, n, std::forward(predicate)); + case scalar_kind_t::f16_k: + return index->filtered_search((f16_t const*)vector, n, std::forward(predicate)); + case scalar_kind_t::bf16_k: + return index->filtered_search((bf16_t const*)vector, n, std::forward(predicate)); + case scalar_kind_t::i8_k: + return index->filtered_search((i8_t const*)vector, n, std::forward(predicate)); + case scalar_kind_t::b1x8_k: + return index->filtered_search((b1x8_t const*)vector, n, std::forward(predicate)); + default: return search_result_t(*index).failed("Unknown scalar kind!"); + } +} + +extern "C" { + +USEARCH_EXPORT char const* usearch_version(void) { + int major = USEARCH_VERSION_MAJOR; + int minor = USEARCH_VERSION_MINOR; + int patch = USEARCH_VERSION_PATCH; + static char version[32]; + std::snprintf(version, sizeof(version), "%d.%d.%d", major, minor, patch); + return version; +} + +USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, usearch_error_t* error) { + + USEARCH_ASSERT(error && "Missing arguments"); + + // The user may want to initialize from a file. + // In that case he may pass NULL options, and we will try to load the metadata from the file. + if (!options) { + index_dense_t* result_ptr = new index_dense_t(); + if (!result_ptr) + *error = "Out of memory!"; + return result_ptr; + } + + index_dense_config_t config; + config.connectivity = options->connectivity; + config.expansion_add = options->expansion_add; + config.expansion_search = options->expansion_search; + config.multi = options->multi; + config.enable_key_lookups = 1; + + metric_kind_t metric_kind = metric_kind_to_cpp(options->metric_kind); + scalar_kind_t scalar_kind = scalar_kind_to_cpp(options->quantization); + metric_punned_t metric = // + !options->metric ? metric_punned_t::builtin(options->dimensions, metric_kind, scalar_kind) + : metric_punned_t::stateless(options->dimensions, // + reinterpret_cast(options->metric), // + metric_punned_signature_t::array_array_k, // + metric_kind, scalar_kind); + if (metric.missing()) { + *error = "Unknown metric kind!"; + return NULL; + } + + using state_result_t = typename index_dense_t::state_result_t; + state_result_t state = index_dense_t::make(metric, config); + if (!state) + *error = state.error.release(); + index_dense_t* result_ptr = new index_dense_t(std::move(state.index)); + if (!result_ptr) + *error = "Out of memory!"; + + // Let's immediately make it usable by reserving enough threads for this machine: + if (!result_ptr->try_reserve(index_limits_t())) + *error = "Out of memory when preparing contexts!"; + + return result_ptr; +} + +USEARCH_EXPORT void usearch_free(usearch_index_t index, usearch_error_t*) { + delete reinterpret_cast(index); +} + +USEARCH_EXPORT size_t usearch_serialized_length(usearch_index_t index, usearch_error_t*) { + USEARCH_ASSERT(index && "Missing arguments"); + return reinterpret_cast(index)->serialized_length(); +} + +USEARCH_EXPORT void usearch_save(usearch_index_t index, char const* path, usearch_error_t* error) { + + USEARCH_ASSERT(index && path && error && "Missing arguments"); + serialization_result_t result = reinterpret_cast(index)->save(path); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_load(usearch_index_t index, char const* path, usearch_error_t* error) { + + USEARCH_ASSERT(index && path && error && "Missing arguments"); + serialization_result_t result = reinterpret_cast(index)->load(path); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_view(usearch_index_t index, char const* path, usearch_error_t* error) { + + USEARCH_ASSERT(index && path && error && "Missing arguments"); + serialization_result_t result = reinterpret_cast(index)->view(path); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_metadata(char const* path, usearch_init_options_t* options, usearch_error_t* error) { + + USEARCH_ASSERT(path && options && error && "Missing arguments"); + index_dense_metadata_result_t result = index_dense_metadata_from_path(path); + if (!result) + *error = result.error.release(); + + options->metric_kind = metric_kind_to_c(result.head.kind_metric); + options->quantization = scalar_kind_to_c(result.head.kind_scalar); + options->dimensions = result.head.dimensions; + options->multi = result.head.multi; + + options->connectivity = 0; + options->expansion_add = 0; + options->expansion_search = 0; + options->metric = NULL; +} + +USEARCH_EXPORT void usearch_save_buffer(usearch_index_t index, void* buffer, size_t length, usearch_error_t* error) { + + USEARCH_ASSERT(index && buffer && length && error && "Missing arguments"); + memory_mapped_file_t memory_map((byte_t*)buffer, length); + serialization_result_t result = reinterpret_cast(index)->save(std::move(memory_map)); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_load_buffer(usearch_index_t index, void const* buffer, size_t length, + usearch_error_t* error) { + + USEARCH_ASSERT(index && buffer && length && error && "Missing arguments"); + memory_mapped_file_t memory_map((byte_t*)buffer, length); + serialization_result_t result = reinterpret_cast(index)->load(std::move(memory_map)); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_view_buffer(usearch_index_t index, void const* buffer, size_t length, + usearch_error_t* error) { + + USEARCH_ASSERT(index && buffer && length && error && "Missing arguments"); + memory_mapped_file_t memory_map((byte_t*)buffer, length); + serialization_result_t result = reinterpret_cast(index)->view(std::move(memory_map)); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT void usearch_metadata_buffer(void const* buffer, size_t length, usearch_init_options_t* options, + usearch_error_t* error) { + + USEARCH_ASSERT(buffer && length && options && error && "Missing arguments"); + index_dense_metadata_result_t result = + index_dense_metadata_from_buffer(memory_mapped_file_t((byte_t*)(buffer), length)); + if (!result) + *error = result.error.release(); + + options->metric_kind = metric_kind_to_c(result.head.kind_metric); + options->quantization = scalar_kind_to_c(result.head.kind_scalar); + options->dimensions = result.head.dimensions; + options->multi = result.head.multi; + + options->connectivity = 0; + options->expansion_add = 0; + options->expansion_search = 0; + options->metric = NULL; +} + +USEARCH_EXPORT size_t usearch_size(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->size(); +} + +USEARCH_EXPORT size_t usearch_capacity(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->capacity(); +} + +USEARCH_EXPORT size_t usearch_dimensions(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->dimensions(); +} + +USEARCH_EXPORT size_t usearch_connectivity(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->connectivity(); +} + +USEARCH_EXPORT size_t usearch_expansion_add(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->expansion_add(); +} + +USEARCH_EXPORT size_t usearch_expansion_search(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->expansion_search(); +} + +USEARCH_EXPORT size_t usearch_memory_usage(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->memory_usage(); +} + +USEARCH_EXPORT char const* usearch_hardware_acceleration(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + return reinterpret_cast(index)->metric().isa_name(); +} + +USEARCH_EXPORT void usearch_change_expansion_add(usearch_index_t index, size_t expansion, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + reinterpret_cast(index)->change_expansion_add(expansion); +} + +USEARCH_EXPORT void usearch_change_expansion_search(usearch_index_t index, size_t expansion, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + reinterpret_cast(index)->change_expansion_search(expansion); +} + +USEARCH_EXPORT void usearch_change_threads_add(usearch_index_t index, size_t threads, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + auto& index_dense = *reinterpret_cast(index); + index_limits_t limits = index_dense.limits(); + limits.threads_add = threads; + index_dense.try_reserve(limits); +} + +USEARCH_EXPORT void usearch_change_threads_search(usearch_index_t index, size_t threads, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + auto& index_dense = *reinterpret_cast(index); + index_limits_t limits = index_dense.limits(); + limits.threads_search = threads; + index_dense.try_reserve(limits); +} + +USEARCH_EXPORT void usearch_change_metric_kind(usearch_index_t index, usearch_metric_kind_t kind, + usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + auto& index_dense = *reinterpret_cast(index); + index_dense.change_metric( + metric_punned_t::builtin(index_dense.dimensions(), metric_kind_to_cpp(kind), index_dense.scalar_kind())); +} + +USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, void* state, + usearch_metric_kind_t kind, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + auto& index_dense = *reinterpret_cast(index); + auto metric_punned = + state ? metric_punned_t::stateful(index_dense.dimensions(), reinterpret_cast(metric), + reinterpret_cast(state), metric_kind_to_cpp(kind), + index_dense.scalar_kind()) + : metric_punned_t::stateless(index_dense.dimensions(), reinterpret_cast(metric), + metric_punned_signature_t::array_array_k, metric_kind_to_cpp(kind), + index_dense.scalar_kind()); + index_dense.change_metric(std::move(metric_punned)); +} + +USEARCH_EXPORT void usearch_reserve(usearch_index_t index, size_t capacity, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + if (!reinterpret_cast(index)->try_reserve(capacity)) + *error = "Out of memory!"; +} + +USEARCH_EXPORT void usearch_add( // + usearch_index_t index, usearch_key_t key, void const* vector, usearch_scalar_kind_t kind, // + usearch_error_t* error) { + + USEARCH_ASSERT(index && vector && error && "Missing arguments"); + add_result_t result = add_(reinterpret_cast(index), key, vector, scalar_kind_to_cpp(kind)); + if (!result) + *error = result.error.release(); +} + +USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_key_t key, usearch_error_t*) { + USEARCH_ASSERT(index && "Missing arguments"); + return reinterpret_cast(index)->contains(key); +} + +USEARCH_EXPORT size_t usearch_count(usearch_index_t index, usearch_key_t key, usearch_error_t*) { + USEARCH_ASSERT(index && "Missing arguments"); + return reinterpret_cast(index)->count(key); +} + +USEARCH_EXPORT size_t usearch_search( // + usearch_index_t index, void const* query, usearch_scalar_kind_t query_kind, size_t results_limit, // + usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) { + + USEARCH_ASSERT(index && query && error && "Missing arguments"); + search_result_t result = + search_(reinterpret_cast(index), query, scalar_kind_to_cpp(query_kind), results_limit); + if (!result) { + *error = result.error.release(); + return 0; + } + + return result.dump_to(found_keys, found_distances, results_limit); +} + +USEARCH_EXPORT size_t usearch_filtered_search( // + usearch_index_t index, // + void const* query, usearch_scalar_kind_t query_kind, size_t results_limit, // + int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, // + usearch_key_t* found_keys, usearch_distance_t* found_distances, usearch_error_t* error) { + + USEARCH_ASSERT(index && query && filter && error && "Missing arguments"); + search_result_t result = + search_(reinterpret_cast(index), query, scalar_kind_to_cpp(query_kind), results_limit, + [=](usearch_key_t key) noexcept { return filter(key, filter_state); }); + if (!result) { + *error = result.error.release(); + return 0; + } + + return result.dump_to(found_keys, found_distances, results_limit); +} + +USEARCH_EXPORT size_t usearch_get( // + usearch_index_t index, usearch_key_t key, size_t count, // + void* vectors, usearch_scalar_kind_t kind, usearch_error_t*) { + + USEARCH_ASSERT(index && vectors); + return get_(reinterpret_cast(index), key, count, vectors, scalar_kind_to_cpp(kind)); +} + +USEARCH_EXPORT size_t usearch_remove(usearch_index_t index, usearch_key_t key, usearch_error_t* error) { + + USEARCH_ASSERT(index && error && "Missing arguments"); + labeling_result_t result = reinterpret_cast(index)->remove(key); + if (!result) + *error = result.error.release(); + return result.completed; +} + +USEARCH_EXPORT size_t usearch_rename( // + usearch_index_t index, usearch_key_t from, usearch_key_t to, usearch_error_t* error) { + + USEARCH_ASSERT(index && error && "Missing arguments"); + labeling_result_t result = reinterpret_cast(index)->rename(from, to); + if (!result) + *error = result.error.release(); + return result.completed; +} + +USEARCH_EXPORT usearch_distance_t usearch_distance( // + void const* vector_first, void const* vector_second, // + usearch_scalar_kind_t scalar_kind, size_t dimensions, // + usearch_metric_kind_t metric_kind, usearch_error_t* error) { + + (void)error; + metric_punned_t metric(dimensions, metric_kind_to_cpp(metric_kind), scalar_kind_to_cpp(scalar_kind)); + return metric((byte_t const*)vector_first, (byte_t const*)vector_second); +} + +USEARCH_EXPORT void usearch_exact_search( // + void const* dataset, size_t dataset_count, size_t dataset_stride, // + void const* queries, size_t queries_count, size_t queries_stride, // + usearch_scalar_kind_t scalar_kind, size_t dimensions, // + usearch_metric_kind_t metric_kind, size_t count, size_t threads, // + usearch_key_t* keys, size_t keys_stride, // + usearch_distance_t* distances, size_t distances_stride, // + usearch_error_t* error) { + + USEARCH_ASSERT(dataset && queries && keys && distances && error && "Missing arguments"); + + metric_punned_t metric(dimensions, metric_kind_to_cpp(metric_kind), scalar_kind_to_cpp(scalar_kind)); + executor_default_t executor(threads); + static exact_search_t search; + exact_search_results_t result = search( // + (byte_t const*)dataset, dataset_count, dataset_stride, // + (byte_t const*)queries, queries_count, queries_stride, // + count, metric); + + if (!result) { + *error = "Out of memory, allocating a temporary buffer for batch results"; + return; + } + + // Export results into the output buffer + for (std::size_t query_idx = 0; query_idx != queries_count; ++query_idx) { + auto query_result = result.at(query_idx); + auto query_keys = (usearch_key_t*)((byte_t*)keys + query_idx * keys_stride); + auto query_distances = (usearch_distance_t*)((byte_t*)distances + query_idx * distances_stride); + for (std::size_t i = 0; i != count; ++i) + query_keys[i] = static_cast(query_result[i].offset), + query_distances[i] = static_cast(query_result[i].distance); + } +} + +USEARCH_EXPORT void usearch_clear(usearch_index_t index, usearch_error_t* error) { + USEARCH_ASSERT(index && error && "Missing arguments"); + reinterpret_cast(index)->clear(); +} +} diff --git a/zig/usearch/include/usearch.h b/zig/usearch/include/usearch.h new file mode 100644 index 000000000..c61164590 --- /dev/null +++ b/zig/usearch/include/usearch.h @@ -0,0 +1,487 @@ +#ifndef UNUM_USEARCH_H +#define UNUM_USEARCH_H + +#include // `bool` +#include // `size_t` +#include // `uint64_t` + +#ifdef __cplusplus +extern "C" { +#endif + +#ifndef USEARCH_EXPORT +#if defined(_WIN32) && !defined(__MINGW32__) +#define USEARCH_EXPORT __declspec(dllexport) +#else +#define USEARCH_EXPORT +#endif +#endif + +USEARCH_EXPORT typedef void* usearch_index_t; +USEARCH_EXPORT typedef uint64_t usearch_key_t; +USEARCH_EXPORT typedef float usearch_distance_t; + +/** + * @brief Pointer to a null-terminated error message. + * Returned error messages @b don't need to be deallocated. + */ +USEARCH_EXPORT typedef char const* usearch_error_t; + +/** + * @brief Type-punned callback for "metrics" or "distance functions", + * that accepts pointers to two vectors and measures their @b dis-similarity. + */ +USEARCH_EXPORT typedef usearch_distance_t (*usearch_metric_t)(void const*, void const*); + +/** + * @brief Enumerator for the most common kinds of `usearch_metric_t`. + * Those are supported out of the box, with SIMD-optimizations for most common hardware. + */ +USEARCH_EXPORT typedef enum usearch_metric_kind_t { + usearch_metric_unknown_k = 0, + usearch_metric_cos_k = 1, + usearch_metric_ip_k = 2, + usearch_metric_l2sq_k = 3, + usearch_metric_haversine_k = 4, + usearch_metric_divergence_k = 5, + usearch_metric_pearson_k = 6, + usearch_metric_jaccard_k = 7, + usearch_metric_hamming_k = 8, + usearch_metric_tanimoto_k = 9, + usearch_metric_sorensen_k = 10, +} usearch_metric_kind_t; + +USEARCH_EXPORT typedef enum usearch_scalar_kind_t { + usearch_scalar_unknown_k = 0, + usearch_scalar_f32_k = 1, + usearch_scalar_f64_k = 2, + usearch_scalar_f16_k = 3, + usearch_scalar_i8_k = 4, + usearch_scalar_b1_k = 5, + usearch_scalar_bf16_k = 6, +} usearch_scalar_kind_t; + +USEARCH_EXPORT typedef struct usearch_init_options_t { + /** + * @brief The metric kind used for distance calculation between vectors. + */ + usearch_metric_kind_t metric_kind; + /** + * @brief The @b optional custom distance metric function used for distance calculation between vectors. + * If the `metric_kind` is set to `usearch_metric_unknown_k`, this function pointer mustn't be `NULL`. + */ + usearch_metric_t metric; + /** + * @brief The scalar kind used for quantization of vector data during indexing. + * In most cases, on modern hardware, it's recommended to use half-precision floating-point numbers. + * When quantization is enabled, the "get"-like functions won't be able to recover the original data, + * so you may want to replicate the original vectors elsewhere. + * + * Quantizing to integers is also possible, but it's important to note that it's only valid for cosine-like + * metrics. As part of the quantization process, the vectors are normalized to unit length and later scaled + * to @b [-127,127] range to occupy the full 8-bit range. + * + * Quantizing to 1-bit booleans is also possible, but it's only valid for binary metrics like Jaccard, Hamming, + * etc. As part of the quantization process, the scalar components greater than zero are set to `true`, and the + * rest to `false`. + */ + usearch_scalar_kind_t quantization; + /** + * @brief The number of dimensions in the vectors to be indexed. + * Must be defined for most metrics, but can be avoided for `usearch_metric_haversine_k`. + */ + size_t dimensions; + /** + * @brief The @b optional connectivity parameter that limits connections-per-node in graph. + */ + size_t connectivity; + /** + * @brief The @b optional expansion factor used for index construction when adding vectors. + */ + size_t expansion_add; + /** + * @brief The @b optional expansion factor used for index construction during search operations. + */ + size_t expansion_search; + /** + * @brief When set allows multiple vectors to map to the same key. + */ + bool multi; +} usearch_init_options_t; + +/** + * @brief Retrieves the version of the library. + * @return The version of the library. + */ +USEARCH_EXPORT char const* usearch_version(void); + +/** + * @brief Initializes a new instance of the index. + * @param options Pointer to the `usearch_init_options_t` structure containing initialization options. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return A handle to the initialized USearch index, or `NULL` on failure. + */ +USEARCH_EXPORT usearch_index_t usearch_init(usearch_init_options_t* options, usearch_error_t* error); + +/** + * @brief Frees the resources associated with the index. + * @param[inout] index The handle to the USearch index to be freed. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_free(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports the memory usage of the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of bytes used by the index. + */ +USEARCH_EXPORT size_t usearch_memory_usage(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports the SIMD capabilities used by the index on the current CPU. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return The codename of the SIMD instruction set used by the index. + */ +USEARCH_EXPORT char const* usearch_hardware_acceleration(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports expected file size after serialization. + * @param[in] index The handle to the USearch index to be serialized. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT size_t usearch_serialized_length(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Saves the index to a file. + * @param[in] index The handle to the USearch index to be serialized. + * @param[in] path The file path where the index will be saved. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_save(usearch_index_t index, char const* path, usearch_error_t* error); + +/** + * @brief Loads the index from a file. + * @param[inout] index The handle to the USearch index to be populated from path. + * @param[in] path The file path from where the index will be loaded. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_load(usearch_index_t index, char const* path, usearch_error_t* error); + +/** + * @brief Creates a view of the index from a file without copying it into memory. + * @param[inout] index The handle to the USearch index to be populated with a file view. + * @param[in] path The file path from where the view will be created. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_view(usearch_index_t index, char const* path, usearch_error_t* error); + +/** + * @brief Loads index metadata from a file. + * @param[in] path The file path from where the index will be loaded. + * @param[out] options Pointer to the `usearch_init_options_t` structure to be populated. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_metadata(char const* path, usearch_init_options_t* options, usearch_error_t* error); + +/** + * @brief Saves the index to an in-memory buffer. + * @param[in] index The handle to the USearch index to be serialized. + * @param[in] buffer The in-memory continuous buffer where the index will be saved. + * @param[in] length The length of the buffer in bytes. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_save_buffer(usearch_index_t index, void* buffer, size_t length, usearch_error_t* error); + +/** + * @brief Loads the index from an in-memory buffer. + * @param[inout] index The handle to the USearch index to be populated from buffer. + * @param[in] buffer The in-memory continuous buffer from where the index will be loaded. + * @param[in] length The length of the buffer in bytes. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_load_buffer(usearch_index_t index, void const* buffer, size_t length, + usearch_error_t* error); + +/** + * @brief Creates a view of the index from an in-memory buffer without copying it into memory. + * @param[inout] index The handle to the USearch index to be populated with a buffer view. + * @param[in] buffer The in-memory continuous buffer from where the view will be created. + * @param[in] length The length of the buffer in bytes. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_view_buffer(usearch_index_t index, void const* buffer, size_t length, + usearch_error_t* error); + +/** + * @brief Loads index metadata from an in-memory buffer. + * @param[in] buffer The in-memory continuous buffer from where the view will be created. + * @param[out] options Pointer to the `usearch_init_options_t` structure to be populated. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_metadata_buffer(void const* buffer, size_t length, usearch_init_options_t* options, + usearch_error_t* error); + +/** + * @brief Reports the current size (number of vectors) of the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT size_t usearch_size(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports the current capacity (number of vectors) of the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT size_t usearch_capacity(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports the current dimensions of the vectors in the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT size_t usearch_dimensions(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reports the current connectivity of the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT size_t usearch_connectivity(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Reserves memory for a specified number of incoming vectors. + * @param[inout] index The handle to the USearch index to be resized. + * @param[in] capacity The desired total capacity including current size. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_reserve(usearch_index_t index, size_t capacity, usearch_error_t* error); + +/** + * @brief Retrieves the expansion value used during index creation. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return The expansion value used during index creation. + */ +USEARCH_EXPORT size_t usearch_expansion_add(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Retrieves the expansion value used during search. + * @param[in] index The handle to the USearch index to be queried. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return The expansion value used during search. + */ +USEARCH_EXPORT size_t usearch_expansion_search(usearch_index_t index, usearch_error_t* error); + +/** + * @brief Updates the expansion value used during index creation. Rarely used. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] expansion The new expansion value. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_expansion_add(usearch_index_t index, size_t expansion, usearch_error_t* error); + +/** + * @brief Updates the expansion value used during search. Rarely used. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] expansion The new expansion value. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_expansion_search(usearch_index_t index, size_t expansion, usearch_error_t* error); + +/** + * @brief Updates the number of threads that would be used to construct the index. Rarely used. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] threads The new limit for the number of concurrent threads. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_threads_add(usearch_index_t index, size_t threads, usearch_error_t* error); + +/** + * @brief Updates the number of threads that will be performing concurrent traversals. Rarely used. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] threads The new limit for the number of concurrent threads. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_threads_search(usearch_index_t index, size_t threads, usearch_error_t* error); + +/** + * @brief Updates the metric kind used for distance calculation between vectors. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] kind The metric kind used for distance calculation between vectors. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_metric_kind(usearch_index_t index, usearch_metric_kind_t kind, + usearch_error_t* error); + +/** + * @brief Updates the custom metric function used for distance calculation between vectors. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] metric The custom metric function used for distance calculation between vectors. + * @param[in] state The @b optional state pointer to be passed to the custom metric function. + * @param[in] kind The metric kind used for distance calculation between vectors. Needed for serialization. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_change_metric(usearch_index_t index, usearch_metric_t metric, void* state, + usearch_metric_kind_t kind, usearch_error_t* error); + +/** + * @brief Adds a vector with a key to the index. + * @param[inout] index The handle to the USearch index to be populated. + * @param[in] key The key associated with the vector. + * @param[in] vector Pointer to the vector data. + * @param[in] vector_kind The scalar type used in the vector data. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_add( // + usearch_index_t index, usearch_key_t key, // + void const* vector, usearch_scalar_kind_t vector_kind, usearch_error_t* error); + +/** + * @brief Checks if the index contains a vector with a specific key. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] key The key to be checked. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return `true` if the index contains the vector with the given key, `false` otherwise. + */ +USEARCH_EXPORT bool usearch_contains(usearch_index_t index, usearch_key_t key, usearch_error_t* error); + +/** + * @brief Counts the number of entries in the index under a specific key. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] key The key to be checked. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of vectors found under that key. + */ +USEARCH_EXPORT size_t usearch_count(usearch_index_t index, usearch_key_t key, usearch_error_t* error); + +/** + * @brief Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to query. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] query_vector Pointer to the query vector data. + * @param[in] query_kind The scalar type used in the query vector data. + * @param[in] count Upper bound on the number of neighbors to search, the "k" in "kANN". + * @param[out] keys Output buffer for up to `count` nearest neighbors keys. + * @param[out] distances Output buffer for up to `count` distances to nearest neighbors. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of found matches. + */ +USEARCH_EXPORT size_t usearch_search( // + usearch_index_t index, // + void const* query_vector, usearch_scalar_kind_t query_kind, size_t count, // + usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error); + +/** + * @brief Performs k-Approximate Nearest Neighbors (kANN) Search for closest vectors to query, + * predicated on a custom function that returns `true` for vectors to be included. + * + * @param[in] index The handle to the USearch index to be queried. + * @param[in] query_vector Pointer to the query vector data. + * @param[in] query_kind The scalar type used in the query vector data. + * @param[in] count Upper bound on the number of neighbors to search, the "k" in "kANN". + * @param[in] filter The custom filter function that returns `true` for vectors to be included. + * @param[in] filter_state The @b optional state pointer to be passed to the custom filter function. + * @param[out] keys Output buffer for up to `count` nearest neighbors keys. + * @param[out] distances Output buffer for up to `count` distances to nearest neighbors. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of found matches. + */ +USEARCH_EXPORT size_t usearch_filtered_search( // + usearch_index_t index, // + void const* query_vector, usearch_scalar_kind_t query_kind, size_t count, // + int (*filter)(usearch_key_t key, void* filter_state), void* filter_state, // + usearch_key_t* keys, usearch_distance_t* distances, usearch_error_t* error); + +/** + * @brief Retrieves the vector associated with the given key from the index. + * @param[in] index The handle to the USearch index to be queried. + * @param[in] key The key of the vector to retrieve. + * @param[out] vector Pointer to the memory where the vector data will be copied. + * @param[in] count Number of vectors that can be fitted into `vector` for multi-vector entries. + * @param[in] vector_kind The scalar type used in the vector data. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of vectors found under that name and exported to `vector`. + */ +USEARCH_EXPORT size_t usearch_get( // + usearch_index_t index, usearch_key_t key, size_t count, // + void* vector, usearch_scalar_kind_t vector_kind, usearch_error_t* error); + +/** + * @brief Removes the vector associated with the given key from the index. + * @param[inout] index The handle to the USearch index to be modified. + * @param[in] key The key of the vector to be removed. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of vectors found under that name and dropped from the index. + */ +USEARCH_EXPORT size_t usearch_remove(usearch_index_t index, usearch_key_t key, usearch_error_t* error); + +/** + * @brief Renames the vector to map to a different key. + * @param[inout] index The handle to the USearch index to be modified. + * @param[in] from The key of the vector to be renamed. + * @param[in] to New key for found entry. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Number of vectors found under that name and renamed. + */ +USEARCH_EXPORT size_t usearch_rename(usearch_index_t index, usearch_key_t from, usearch_key_t to, + usearch_error_t* error); + +/** + * @brief Computes the distance between two equi-dimensional vectors. + * @param[in] vector_first The first vector for comparison. + * @param[in] vector_second The second vector for comparison. + * @param[in] scalar_kind The scalar type used in the vectors. + * @param[in] dimensions The number of dimensions in each vector. + * @param[in] metric_kind The metric kind used for distance calculation between vectors. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + * @return Distance between given vectors. + */ +USEARCH_EXPORT usearch_distance_t usearch_distance( // + void const* vector_first, void const* vector_second, // + usearch_scalar_kind_t scalar_kind, size_t dimensions, // + usearch_metric_kind_t metric_kind, usearch_error_t* error); + +/** + * @brief Multi-threaded many-to-many exact nearest neighbors search for equi-dimensional vectors. + * @param[in] dataset Pointer to the first scalar of the dataset matrix. + * @param[in] queries Pointer to the first scalar of the queries matrix. + * @param[in] dataset_size Number of vectors in the `dataset`. + * @param[in] queries_size Number of vectors in the `queries` set. + * @param[in] dataset_stride Number of bytes between starts of consecutive vectors in `dataset`. + * @param[in] queries_stride Number of bytes between starts of consecutive vectors in `queries`. + * @param[in] scalar_kind The scalar type used in the vectors. + * @param[in] dimensions The number of dimensions in each vector. + * @param[in] metric_kind The metric kind used for distance calculation between vectors. + * @param[in] count Upper bound on the number of neighbors to search, the "k" in "kANN". + * @param[in] threads Upper bound for the number of CPU threads to use. + * @param[out] keys Output matrix for `queries_size * count` nearest neighbors keys. Each row of the + * matrix must be contiguous in memory, but different rows can be separated by `keys_stride` bytes. + * @param[in] keys_stride Number of bytes between starts of consecutive rows od scalars in `keys`. + * @param[out] distances Output matrix for `queries_size * count` distances to nearest neighbors. Each row of the + * matrix must be contiguous in memory, but different rows can be separated by `keys_stride` bytes. + * @param[in] distances_stride Number of bytes between starts of consecutive rows od scalars in `distances`. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_exact_search( // + void const* dataset, size_t dataset_size, size_t dataset_stride, // + void const* queries, size_t queries_size, size_t queries_stride, // + usearch_scalar_kind_t scalar_kind, size_t dimensions, // + usearch_metric_kind_t metric_kind, size_t count, size_t threads, // + usearch_key_t* keys, size_t keys_stride, // + usearch_distance_t* distances, size_t distances_stride, // + usearch_error_t* error); + +/** + * @brief Erases all the vectors from the index. + * @param[inout] index The handle to the USearch index to be modified. + * @param[out] error Pointer to a string where the error message will be stored, if an error occurs. + */ +USEARCH_EXPORT void usearch_clear(usearch_index_t index, usearch_error_t* error); + +#ifdef __cplusplus +} +#endif + +#endif // UNUM_USEARCH_H From a6603d2746678f0d8484d0d485eb8ff40c05901f Mon Sep 17 00:00:00 2001 From: Adib Mohsin Date: Sat, 14 Feb 2026 13:40:47 +0600 Subject: [PATCH 2/2] added antarys as an integration --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 242d66e20..220c9370d 100644 --- a/README.md +++ b/README.md @@ -542,6 +542,7 @@ index = Index(ndim=ndim, metric=CompiledMetric( - [x] Sentence-Transformers: Python [docs](https://www.sbert.net/docs/package_reference/quantization.html#sentence_transformers.quantization.semantic_search_usearch). - [x] Pathway: [Rust](https://github.com/pathwaycom/pathway). - [x] Vald: [GoLang](https://github.com/vdaas/vald). +- [x] Antarys: [Zig](https://github.com/antarys-ai/edge). ## Citations