Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ tests = [
"pytest>=9.0.2",
"pytest-repeat>=0.9.4",
]
typing = [
"mypy>=1.19.1",
"types-networkx>=3.6.1.20260321",
"types-setuptools>=82.0.0.20260210",
"types-tqdm>=4.67.3.20260303",
]

[tool.ruff]
line-length = 120
Expand All @@ -104,8 +110,23 @@ quote-style = "double"

[tool.ruff.lint]
select = [
"E9",
"F63",
"F7",
"F82",
"E9",
"F63",
"F7",
"F82",
"TC",
"T20",
"UP",
"I",
]
future-annotations = true # For TC rules

[tool.mypy]
python_version = "3.10"

[[tool.mypy.overrides]]
# these modules not typed and don't have stubs
module = [
"matplotlib",
]
ignore_missing_imports = true
44 changes: 25 additions & 19 deletions python/usearch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import sys
import ctypes
import os
import platform
import warnings
import sys
import urllib.request
from typing import Optional, Tuple
import warnings
from urllib.error import HTTPError

#! Load NumKong before the USearch compiled module
Expand All @@ -30,14 +29,12 @@
pass # If the user doesn't want NumKong, we assume they know what they're doing


from usearch.compiled import (
VERSION_MAJOR,
VERSION_MINOR,
VERSION_PATCH,
from usearch.compiled import ( # type: ignore[import-not-found]
# Default values:
DEFAULT_CONNECTIVITY,
DEFAULT_EXPANSION_ADD,
DEFAULT_EXPANSION_SEARCH,
USES_FP16LIB,
# Dependencies:
USES_OPENMP,
USES_NUMKONG,
Expand All @@ -53,7 +50,7 @@


class BinaryManager:
def __init__(self, version: Optional[str] = None):
def __init__(self, version: str | None = None):
if version is None:
version = __version__
self.version = version or __version__
Expand All @@ -76,7 +73,7 @@ def determine_download_url(version: str, filename: str) -> str:
url = f"{base_url}/v{version}/{filename}"
return url

def get_binary_name(self) -> Tuple[str, str]:
def get_binary_name(self) -> tuple[str, str]:
version = self.version
os_map = {"Linux": "linux", "Windows": "windows", "Darwin": "macos"}
arch_map = {
Expand All @@ -89,12 +86,14 @@ def get_binary_name(self) -> Tuple[str, str]:
os_part = os_map.get(platform.system(), "")
arch = platform.machine()
arch_part = arch_map.get(arch, "")
extension = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}.get(platform.system(), "")
extension = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}.get(
platform.system(), ""
)
source_filename = f"usearch_sqlite_{os_part}_{arch_part}_{version}.{extension}"
target_filename = f"usearch_sqlite.{extension}"
return source_filename, target_filename

def sqlite_found_or_downloaded(self) -> Optional[str]:
def sqlite_found_or_downloaded(self) -> str | None:
"""
Attempts to locate the pre-installed `usearch_sqlite` binary.
If not found, downloads it from GitHub.
Expand All @@ -108,7 +107,6 @@ def sqlite_found_or_downloaded(self) -> Optional[str]:

# Check local development directories first
for local_dir in local_dirs:

local_path = os.path.join(local_dir, target_filename)
if os.path.exists(local_path):
path_wout_extension, _, _ = local_path.rpartition(".")
Expand All @@ -124,31 +122,39 @@ def sqlite_found_or_downloaded(self) -> Optional[str]:
download_dir = self.determine_download_dir()
local_path = os.path.join(download_dir, target_filename)
if not os.path.exists(local_path):

# If not found locally, warn the user and download from GitHub
warnings.warn("Will download `usearch_sqlite` binary from GitHub.", UserWarning)
warnings.warn(
"Will download `usearch_sqlite` binary from GitHub.", UserWarning
)
try:
source_url = self.determine_download_url(self.version, source_filename)
os.makedirs(download_dir, exist_ok=True)
urllib.request.urlretrieve(source_url, local_path)
except HTTPError as e:
# If the download fails due to HTTPError (e.g., 404 Not Found), like a missing lib version
if e.code == 404:
warnings.warn(f"Download failed: {e.url} could not be found.", UserWarning)
warnings.warn(
f"Download failed: {e.url} could not be found.", UserWarning
)
else:
warnings.warn(f"Download failed with HTTP error: {e.code} {e.reason}", UserWarning)
warnings.warn(
f"Download failed with HTTP error: {e.code} {e.reason}",
UserWarning,
)
return None

# Handle the case where binary_path does not exist after supposed successful download
if os.path.exists(local_path):
path_wout_extension, _, _ = local_path.rpartition(".")
return path_wout_extension
else:
warnings.warn("Failed to download `usearch_sqlite` binary from GitHub.", UserWarning)
warnings.warn(
"Failed to download `usearch_sqlite` binary from GitHub.", UserWarning
)
return None


def sqlite_path(version: str = None) -> str:
def sqlite_path(version: str | None = None) -> str:
manager = BinaryManager(version=version)
result = manager.sqlite_found_or_downloaded()
if result is None:
Expand Down
68 changes: 37 additions & 31 deletions python/usearch/client.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Union, Optional, List
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import numpy as np
from ucall.client import Client
from ucall.client import Client # type: ignore[import-untyped]

if TYPE_CHECKING:
from numpy.typing import NDArray

from usearch.index import Matches
from usearch.index import BatchMatches, Matches


def _vector_to_ascii(vector: np.ndarray) -> Optional[str]:
def _vector_to_ascii(vector: NDArray[Any]) -> str | None:
if vector.dtype != np.int8 and vector.dtype != np.uint8 and vector.dtype != np.byte:
return None
if not np.all((vector >= 0) | (vector <= 100)):
Expand All @@ -21,10 +26,12 @@ def _vector_to_ascii(vector: np.ndarray) -> Optional[str]:


class IndexClient:
def __init__(self, uri: str = "127.0.0.1", port: int = 8545, use_http: bool = True) -> None:
def __init__(
self, uri: str = "127.0.0.1", port: int = 8545, use_http: bool = True
) -> None:
self.client = Client(uri=uri, port=port, use_http=use_http)

def add_one(self, key: int, vector: np.ndarray):
def add_one(self, key: int, vector: NDArray[Any]):
assert isinstance(key, int)
assert isinstance(vector, np.ndarray)
vector = vector.flatten()
Expand All @@ -34,57 +41,57 @@ def add_one(self, key: int, vector: np.ndarray):
else:
self.client.add_one(key=key, vectors=vector)

def add_many(self, keys: np.ndarray, vectors: np.ndarray):
def add_many(self, keys: NDArray[Any], vectors: NDArray[Any]):
assert isinstance(keys, int)
assert isinstance(vectors, np.ndarray)
assert keys.ndim == 1 and vectors.ndim == 2
assert keys.shape[0] == vectors.shape[0]
self.client.add_many(keys=keys, vectors=vectors)

def add(self, keys: Union[np.ndarray, int], vectors: np.ndarray):
def add(self, keys: NDArray[Any] | int, vectors: NDArray[Any]):
if isinstance(keys, int) or len(keys) == 1:
return self.add_one(keys, vectors)
return self.add_one(
int(keys) if isinstance(keys, np.ndarray) else keys, vectors
)
else:
return self.add_many(keys, vectors)

def search_one(self, vector: np.ndarray, count: int) -> Matches:
matches: List[dict] = []
def search_one(self, vector: NDArray[Any], count: int) -> Matches:
vector = vector.flatten()
ascii_vector = _vector_to_ascii(vector)
if ascii_vector:
matches = self.client.search_ascii(string=ascii_vector, count=count)
raw = self.client.search_ascii(string=ascii_vector, count=count)
else:
matches = self.client.search_one(vector=vector, count=count)
raw = self.client.search_one(vector=vector, count=count)

print(matches.data)
matches = matches.json
matches: list[dict] = raw.json

keys = np.array((1, count), dtype=np.uint32)
distances = np.array((1, count), dtype=np.float32)
counts = np.array((1), dtype=np.uint32)
keys = np.array(count, dtype=np.uint32)
distances = np.array(count, dtype=np.float32)
for col, result in enumerate(matches):
keys[0, col] = result["key"]
distances[0, col] = result["distance"]
counts[0] = len(matches)
keys[col] = result["key"]
distances[col] = result["distance"]

return keys, distances, counts
return Matches(keys=keys[: len(matches)], distances=distances[: len(matches)])

def search_many(self, vectors: np.ndarray, count: int) -> Matches:
def search_many(self, vectors: NDArray[Any], count: int) -> BatchMatches:
batch_size: int = vectors.shape[0]
list_of_matches: List[List[dict]] = self.client.search_many(vectors=vectors, count=count)
list_of_matches: list[list[dict]] = self.client.search_many(
vectors=vectors, count=count
)

keys = np.array((batch_size, count), dtype=np.uint32)
distances = np.array((batch_size, count), dtype=np.float32)
counts = np.array((batch_size), dtype=np.uint32)
keys = np.zeros((batch_size, count), dtype=np.uint32)
distances = np.zeros((batch_size, count), dtype=np.float32)
counts = np.zeros(batch_size, dtype=np.uint32)
for row, matches in enumerate(list_of_matches):
for col, result in enumerate(matches):
keys[row, col] = result["key"]
distances[row, col] = result["distance"]
counts[row] = len(results)
counts[row] = len(matches)

return keys, distances, counts
return BatchMatches(keys=keys, distances=distances, counts=counts)

def search(self, vectors: np.ndarray, count: int) -> Matches:
def search(self, vectors: NDArray[Any], count: int) -> Matches | BatchMatches:
if vectors.ndim == 1 or (vectors.ndim == 2 and vectors.shape[0] == 1):
return self.search_one(vectors, count)
else:
Expand Down Expand Up @@ -117,4 +124,3 @@ def save(self, path: str):
index = IndexClient()
index.add(42, np.array([0.4] * 256, dtype=np.float32))
results = index.search(np.array([0.4] * 256, dtype=np.float32), 10)
print(results)
Loading
Loading