diff --git a/src/ninetoothed/_cache.py b/src/ninetoothed/_cache.py new file mode 100644 index 0000000..e2c0ff8 --- /dev/null +++ b/src/ninetoothed/_cache.py @@ -0,0 +1,402 @@ +"""Unified cache infrastructure for ninetoothed. + +This module is the single source of truth for cache key derivation and the +two-tier (memory L1 + filesystem L2) Cache class. All ninetoothed cache +points should be built on top of it so the policy (content-sensitivity, +eviction, thread-safety, atomic disk writes) lives in one place. + +Public API: + - Cache: two-tier key/value cache. + - hash_tensor_signature: content-sensitive hash of a ninetoothed.Tensor. + - hash_function_source: content-sensitive hash of a callable. + - hash_value: generic repr-based hash. + - project_files_fingerprint: hash of all files in a directory. +""" + +import ast +import functools +import hashlib +import inspect +import json +import os +import pathlib +import tempfile +import threading +from typing import Any, Callable, Optional + +# Default cache root. Mirrors ninetoothed.generation.CACHE_DIR. +DEFAULT_CACHE_ROOT = pathlib.Path.home() / ".ninetoothed" + + +# ---------- Tensor fingerprint ---------- + + +def hash_tensor_signature(tensor) -> tuple: + """Content-stable fingerprint of a ninetoothed.Tensor. + + Captures the symbolic layout fields that affect code generation while + deliberately ignoring `Tensor.name`, whose auto-generated instance suffix + would otherwise make equivalent fresh Tensor objects miss the cache. + + Two Tensors that share the same structural layout hash equal, even if + they were constructed in different `ninetoothed.make()` calls. + """ + return ( + tensor.ndim, + tuple(_hash_symbolic_value(size) for size in tensor.shape), + _stable_repr(getattr(tensor, "dtype", None)), + tensor.jagged_dim, + _stable_repr(tensor.other), + bool(getattr(tensor, "constexpr", False)), + _stable_repr(getattr(tensor, "value", None)), + ) + + +def _hash_symbolic_value(value: Any) -> tuple: + """Stable representation for Tensor.shape entries.""" + if hasattr(value, "node"): + node = value.node + bounds = ( + getattr(value, "lower_bound", None), + getattr(value, "upper_bound", None), + getattr(value, "power_of_two", None), + ) + if isinstance(node, ast.Constant): + return ("constant", node.value, bounds) + if isinstance(node, ast.Name): + return ("symbol", bounds) + return ("expr", ast.dump(node, include_attributes=False), bounds) + return ("raw", _stable_repr(value)) + + +def _stable_repr(value: Any) -> str: + if isinstance(value, dict): + items = sorted((_stable_repr(k), _stable_repr(v)) for k, v in value.items()) + return "{" + ", ".join(f"{k}: {v}" for k, v in items) + "}" + if isinstance(value, tuple): + return "(" + ", ".join(_stable_repr(v) for v in value) + ")" + if isinstance(value, list): + return "[" + ", ".join(_stable_repr(v) for v in value) + "]" + if isinstance(value, set): + return "{" + ", ".join(sorted(_stable_repr(v) for v in value)) + "}" + if inspect.isfunction(value): + return _function_payload(value, depth=1, seen=set()).hex() + return repr(value) + + +# ---------- Function fingerprint ---------- + + +def hash_function_source(func) -> str: + """Content-sensitive SHA256 hash of a callable. + + `functools.partial` is unwrapped: the inner function's source is hashed + together with `repr()` of the bound args/kwargs, so + `partial(jagged_dim=1)` differs from `partial(jagged_dim=2)`. + + Falls back to a stable `id:` token (module, qualname, id) when + `inspect.getsource` fails (lambdas, C builtins, REPL-defined code). + + Returns a 64-char hex digest prefixed with `src:` or `id:`. + """ + func, partial_args, partial_kwargs = _unwrap_partial(func) + + payload, prefix = _function_payload(func, depth=0, seen=set(), with_prefix=True) + h = hashlib.sha256() + h.update(payload) + h.update(_stable_repr(partial_args).encode("utf-8")) + h.update(_stable_repr(partial_kwargs).encode("utf-8")) + return prefix + h.hexdigest() + + +def _unwrap_partial(func): + layers = [] + while isinstance(func, functools.partial): + layers.append(func) + func = func.func + + args = [] + kwargs = {} + for layer in reversed(layers): + args.extend(layer.args) + if layer.keywords: + kwargs.update(layer.keywords) + + return func, tuple(args), tuple(sorted(kwargs.items())) + + +def _function_payload(func, depth=0, seen=None, with_prefix=False): + if seen is None: + seen = set() + + obj_id = id(func) + if obj_id in seen: + payload = ( + f"recursive:{getattr(func, '__module__', '?')}." + f"{getattr(func, '__qualname__', '?')}" + ).encode("utf-8") + return (payload, "id:") if with_prefix else payload + + seen.add(obj_id) + parts = [ + ("module", getattr(func, "__module__", "?")), + ("qualname", getattr(func, "__qualname__", "?")), + ] + + try: + src = inspect.getsource(func) + except (OSError, TypeError): + src = None + + prefix = "src:" if src is not None else "id:" + parts.append(("source", src)) + code = getattr(func, "__code__", None) + if code is not None: + parts.append(("code", code.co_code.hex())) + parts.append(("defaults", _stable_repr(getattr(func, "__defaults__", None)))) + parts.append( + ("kwdefaults", _stable_repr(getattr(func, "__kwdefaults__", None))) + ) + parts.append(("closure", _closure_payload(func))) + if depth < 2: + globals_ = getattr(func, "__globals__", {}) + for name in sorted(code.co_names): + if name not in globals_: + continue + value = globals_[name] + if inspect.isfunction(value): + parts.append( + ( + "global_func", + name, + _function_payload(value, depth=depth + 1, seen=seen).hex(), + ) + ) + elif isinstance(value, (str, int, float, bool, type(None), tuple)): + parts.append(("global_value", name, _stable_repr(value))) + + payload = _stable_repr(tuple(parts)).encode("utf-8") + return (payload, prefix) if with_prefix else payload + + +def _closure_payload(func) -> tuple: + closure = getattr(func, "__closure__", None) + if not closure: + return () + values = [] + for cell in closure: + try: + values.append(_stable_repr(cell.cell_contents)) + except ValueError: + values.append("") + return tuple(values) + + +def hash_value(value: Any) -> str: + """Generic repr-based hash, stable for any repr-able Python value.""" + return hashlib.sha256(repr(value).encode("utf-8")).hexdigest() + + +# ---------- Cache class ---------- + + +class Cache: + """Two-tier cache: in-memory L1 (FIFO eviction) + filesystem L2 (optional). + + Usage: + cache = Cache(namespace="auto_tuning", suffix=".json") + value = cache.get(key, default=None) + if value is None: + value = ...expensive computation... + cache.put(key, value) + + Disk layout: /. Disk format + controlled by `serializer` / `deserializer` (default JSON). + + Pass neither `cache_dir` nor `namespace` to get a memory-only cache + (no disk writes -- useful for per-process JIT handle caches whose + values are not serializable). + + Thread-safe. + """ + + def __init__( + self, + namespace: Optional[str] = None, + *, + suffix: str = ".json", + serializer: Optional[Callable[[Any], str]] = None, + deserializer: Optional[Callable[[str], Any]] = None, + cache_dir: Optional[pathlib.Path] = None, + max_memory: int = 256, + ): + self._suffix = suffix + self._serializer = serializer if serializer is not None else json.dumps + self._deserializer = deserializer if deserializer is not None else json.loads + self._max_memory = max_memory + + # Resolve disk directory. + if cache_dir is not None: + self._cache_dir = pathlib.Path(cache_dir) + elif namespace is not None: + self._cache_dir = DEFAULT_CACHE_ROOT / namespace + else: + self._cache_dir = None # memory-only mode + + if self._cache_dir is not None: + self._cache_dir.mkdir(parents=True, exist_ok=True) + + self._mem: dict = {} + self._lock = threading.Lock() + + # ----- introspection ----- + + @property + def cache_dir(self) -> Optional[pathlib.Path]: + """Disk directory backing this cache, or None for memory-only.""" + return self._cache_dir + + @property + def is_memory_only(self) -> bool: + return self._cache_dir is None + + @property + def memory_size(self) -> int: + with self._lock: + return len(self._mem) + + # ----- key -> path ----- + + def _path_for(self, key: Any) -> Optional[pathlib.Path]: + if self._cache_dir is None: + return None + h = hashlib.sha256(repr(key).encode("utf-8")).hexdigest() + return self._cache_dir / (h + self._suffix) + + # ----- core API ----- + + def contains(self, key: Any) -> bool: + """True iff `key` is in L1 or L2.""" + with self._lock: + if key in self._mem: + return True + path = self._path_for(key) + return path is not None and path.exists() + + def get(self, key: Any, default: Any = None) -> Any: + """L1 (mem) then L2 (disk). L2 hits are promoted to L1.""" + with self._lock: + if key in self._mem: + return self._mem[key] + + path = self._path_for(key) + if path is None or not path.exists(): + return default + + try: + value = self._deserializer(path.read_text(encoding="utf-8")) + except (OSError, ValueError, json.JSONDecodeError): + return default + + with self._lock: + if len(self._mem) >= self._max_memory: + self._evict_one_unlocked() + self._mem[key] = value + return value + + def put(self, key: Any, value: Any) -> None: + """Write L1 (mem) and L2 (disk). Disk failure leaves L1 intact. + + Disk writes are atomic: serialize to a sibling `.tmp` file, fsync, + then rename over the target. This prevents a concurrent reader from + observing a half-written value if the writer is killed mid-write + (e.g. crash, OOM kill, or power loss). + """ + with self._lock: + if key not in self._mem and len(self._mem) >= self._max_memory: + self._evict_one_unlocked() + self._mem[key] = value + + path = self._path_for(key) + if path is not None: + self._atomic_write(path, value) + + def _atomic_write(self, path: pathlib.Path, value: Any) -> None: + """Serialize, write to `.tmp`, fsync, then rename over `path`. + + Failure modes (OSError, TypeError, ValueError from serializer) leave + both L1 and the on-disk state intact -- no half-written file is + visible to readers, and the L1 entry is still authoritative in-process. + """ + try: + payload = self._serializer(value) + except (TypeError, ValueError): + return + tmp = None + try: + with tempfile.NamedTemporaryFile( + "w", + encoding="utf-8", + dir=path.parent, + prefix=path.name + ".", + suffix=".tmp", + delete=False, + ) as f: + tmp = pathlib.Path(f.name) + f.write(payload) + f.flush() + os.fsync(f.fileno()) + os.replace(tmp, path) + except OSError: + # Best-effort cleanup of the leftover .tmp file. + try: + if tmp is not None and tmp.exists(): + tmp.unlink() + except OSError: + pass + + def _evict_one_unlocked(self) -> None: + if self._max_memory > 0 and self._mem: + self._mem.pop(next(iter(self._mem))) + + def delete(self, key: Any) -> None: + """Remove from L1 and L2. Missing entries are silently ignored.""" + with self._lock: + self._mem.pop(key, None) + path = self._path_for(key) + if path is not None and path.exists(): + try: + path.unlink() + except OSError: + pass + + def clear_memory(self) -> None: + """Clear L1 only; L2 (disk) is preserved for cross-process sharing.""" + with self._lock: + self._mem.clear() + + +# ---------- Project fingerprint ---------- + + +def project_files_fingerprint( + directory: pathlib.Path, exclude_suffixes=(".pyc",) +) -> str: + """SHA256 fingerprint over all files under `directory`. + + Used to namespace caches by the ninetoothed installation version: when + the package code changes, the fingerprint changes and old cache files + are effectively ignored (different subdir, different key prefix). + + `rglob` is sorted for determinism; `.pyc` is excluded by default. + """ + h = hashlib.sha256() + paths = sorted( + p + for p in pathlib.Path(directory).rglob("*") + if p.is_file() and p.suffix not in exclude_suffixes + ) + for p in paths: + h.update(str(p.relative_to(directory)).encode("utf-8")) + h.update(p.read_bytes()) + return h.hexdigest() diff --git a/src/ninetoothed/auto_tuner.py b/src/ninetoothed/auto_tuner.py index 337a1a6..1ab0133 100644 --- a/src/ninetoothed/auto_tuner.py +++ b/src/ninetoothed/auto_tuner.py @@ -1,9 +1,16 @@ -import hashlib -import json +"""Auto-tuner for ninetoothed kernels. + +Migrated to the unified Cache API (ninetoothed._cache.Cache). All timings +are stored as a single JSON per (project, triton-version) directory; the +prior per-func split-file layout is gone -- users with existing caches +should run `rm -rf ~/.ninetoothed/auto_tuning/`. +""" + import os import triton +from ninetoothed._cache import Cache, project_files_fingerprint from ninetoothed.aot import _KernelLaunchError from ninetoothed.generation import CACHE_DIR @@ -16,20 +23,23 @@ def __init__(self, funcs, keys): self._func_to_key = {func: key for func, key in zip(self._funcs, self._keys)} - self._cache_dir = ( - _AUTO_TUNING_CACHE_DIR - / f"{_project_key()}_triton_{triton.__version__.replace('.', '_')}" - ) - self._cache_dir.mkdir(parents=True, exist_ok=True) + # Disk layout: /auto_tuning/_triton_/ + # The project_key isolates caches across ninetoothed versions. + subdir = f"{_project_key()}_triton_{triton.__version__.replace('.', '_')}" + disk_dir = CACHE_DIR / "auto_tuning" / subdir - auto_tuner_key = tuple(self._keys) - cache_key = hashlib.sha256(str(auto_tuner_key).encode("utf-8")).hexdigest() - self._cache_path = self._cache_dir / f"{cache_key}.json" + self._cache = Cache( + cache_dir=disk_dir, + suffix=".json", + max_memory=64, + ) - if self._cache_path.exists(): - self._timings = json.loads(self._cache_path.read_text()) - else: - self._timings = {key: {} for key in self._keys} + # The full timings dict is stored under a single sentinel key. + self._disk_key = ("_all_timings_",) + loaded = self._cache.get(self._disk_key, default={}) + if not loaded: + loaded = {key: {} for key in self._keys} + self._timings = loaded self._best_func = {} @@ -54,9 +64,7 @@ def _get_timings(self, args, kwargs): timings = [self._get_timing(func, args, kwargs) for func in self._funcs] self._timings[arg_key] = timings - - self._cache_path.write_text(json.dumps(self._timings)) - + self._save() return timings def _get_timing(self, func, args, kwargs): @@ -67,31 +75,18 @@ def _get_timing(self, func, args, kwargs): if (arg_key := type(self)._make_arg_key(args, kwargs)) in data: return data[arg_key] - cache_path = self._get_func_cache_path(func) - - if cache_path.exists(): - data |= json.loads(cache_path.read_text()) - - if arg_key in data: - return data[arg_key] - try: timing = triton.testing.do_bench(lambda: func(*args, **kwargs)) except _KernelLaunchError: timing = float("inf") data[arg_key] = timing - - cache_path.write_text(json.dumps(data)) - + self._save() return timing - def _get_func_cache_path(self, func): - func_key = self._func_to_key[func] - cache_key = hashlib.sha256(str(func_key).encode("utf-8")).hexdigest() - cache_path = self._cache_dir / f"{cache_key}.json" - - return cache_path + def _save(self): + """Persist the full timings dict (L1 + L2).""" + self._cache.put(self._disk_key, self._timings) @staticmethod def _make_arg_key(args, kwargs): @@ -118,35 +113,7 @@ def _make_tensor_key(tensor): return f"tensor(shape={tuple(tensor.shape)}, dtype={str(tensor.dtype).split('.')[-1]})" -_AUTO_TUNING_CACHE_DIR = CACHE_DIR / "auto_tuning" - -_FILE_PATH = os.path.abspath(__file__) - -_PARENT_DIR = os.path.dirname(_FILE_PATH) - - def _project_key(): - consolidated_hash = hashlib.sha256() - - for dirpath, dirnames, filenames in os.walk(_PARENT_DIR): - dirnames.sort() - filenames.sort() - - for filename in filenames: - file_path = os.path.join(dirpath, filename) - - if ( - not os.path.isfile(file_path) - or os.path.splitext(file_path)[1] == ".pyc" - ): - continue - - file_hash = _calculate_file_hash(file_path) - consolidated_hash.update(file_hash.encode("utf-8")) - - return consolidated_hash.hexdigest() - - -def _calculate_file_hash(file_path): - with open(file_path, "rb") as f: - return hashlib.sha256(f.read()).hexdigest() + """Fingerprint of the ninetoothed source tree, used to namespace caches + across ninetoothed installation versions.""" + return project_files_fingerprint(os.path.dirname(os.path.abspath(__file__))) diff --git a/src/ninetoothed/debugging.py b/src/ninetoothed/debugging.py index d977142..693d7a7 100644 --- a/src/ninetoothed/debugging.py +++ b/src/ninetoothed/debugging.py @@ -59,7 +59,7 @@ def _arrangement(*tensors): application_source = _generate_debug_application_source(tensors, debug_tensors) - source_file = str(cache_source(application_source)) + source_file = str(cache_source(application_source, "_debug_")) module = import_from_path(source_file, source_file) module_vars = vars(module) diff --git a/src/ninetoothed/generation.py b/src/ninetoothed/generation.py index 6ea1d0f..7561925 100644 --- a/src/ninetoothed/generation.py +++ b/src/ninetoothed/generation.py @@ -124,7 +124,7 @@ def _find_dependencies(func): ["ruff", "format", "-"], input=source, encoding="utf-8" ) - cache_file = cache_source(source) + cache_file = cache_source(source, kernel_name) self.tensors = self._args self.kernel_func = self._func_def @@ -870,8 +870,17 @@ def visit_Call(self, node): return node -def cache_source(source): - digest = hashlib.sha256(source.encode("utf-8")).hexdigest() +def cache_source(source, kernel_name): + # Mix kernel_name into the digest so different kernels derived from + # the same source text (e.g. two block_size configs of the same + # arrangement) do not collide on a single .py file. Without this, + # concurrent AOT compilations can race-write the same cache file, + # leaving triton.tools.compile unable to find the named kernel. + hasher = hashlib.sha256() + hasher.update(source.encode("utf-8")) + hasher.update(b"\0") + hasher.update(kernel_name.encode("utf-8")) + digest = hasher.hexdigest() cache_file = CACHE_DIR / f"{digest}.py" if not cache_file.exists(): diff --git a/src/ninetoothed/make.py b/src/ninetoothed/make.py index 0d09826..08f321f 100644 --- a/src/ninetoothed/make.py +++ b/src/ninetoothed/make.py @@ -1,7 +1,53 @@ +"""Public entry point: ninetoothed.make(), with content-sensitive handle cache. + +The handle cache (L1, in-process, FIFO) is keyed by a content hash of the +arrangement + application source code, tensor structural signatures, and +compilation parameters. Editing the user-facing functions invalidates the +cache; editing unrelated code does not. +""" + import inspect +from ninetoothed._cache import Cache, hash_function_source, hash_tensor_signature from ninetoothed.aot import aot from ninetoothed.jit import jit +from ninetoothed.tensor import Tensor + + +def _build_cache_key( + arrangement, + application, + tensors, + caller, + kernel_name, + num_warps, + num_stages, + max_num_configs, +): + def _hash_one(t): + # Tensor instances get content-sensitive structural hashing. + if isinstance(t, Tensor): + return hash_tensor_signature(t) + # Non-Tensor elements (slices, ints, lists, etc. used as + # arrangement() kwargs) are hashed via repr() so they + # correctly participate in the cache key. + return ("__raw__", repr(t)) + + return ( + hash_function_source(arrangement), + hash_function_source(application), + tuple(_hash_one(t) for t in tensors), + caller, + kernel_name, + num_warps, + num_stages, + max_num_configs, + ) + + +# Per-process L1 cache for JIT handles. Not shared across processes +# (handles are not serializable). 256-entry FIFO matches prior behavior. +_HANDLE_CACHE = Cache(max_memory=256) def make( @@ -24,12 +70,30 @@ def make( :param kernel_name: The name for the generated kernel. :param output_dir: The directory to store the generated files. :param num_warps: The number of warps to use. - :param num_stages: The number of pipeline stages. + :param num_stages: The number of stages to use. :param max_num_configs: The maximum number of auto-tuning configurations to use. :return: A handle to the compute kernel. """ + # Cache only the JIT ("torch") path. The AOT path produces on-disk + # build artifacts (.so, .csv, .fingerprint) that are managed by + # build.py's own cache. + if caller == "torch": + key = _build_cache_key( + arrangement, + application, + tensors, + caller, + kernel_name, + num_warps, + num_stages, + max_num_configs, + ) + cached = _HANDLE_CACHE.get(key) + if cached is not None: + return cached + params = inspect.signature(application).parameters types = arrangement(*tensors) types = types if isinstance(types, tuple) else (types,) @@ -37,7 +101,7 @@ def make( application.__annotations__ = annotations if caller == "torch": - return jit( + handle = jit( application, caller=caller, kernel_name=kernel_name, @@ -45,6 +109,8 @@ def make( num_stages=num_stages, max_num_configs=max_num_configs, ) + _HANDLE_CACHE.put(key, handle) + return handle return aot( application, diff --git a/tests/test_atomic_write.py b/tests/test_atomic_write.py new file mode 100644 index 0000000..1b5e217 --- /dev/null +++ b/tests/test_atomic_write.py @@ -0,0 +1,143 @@ +"""Tests for atomic disk writes in ninetoothed._cache.Cache.put. + +Verifies the contract: + - Successful put leaves exactly one file at the target path (no .tmp residue). + - A mid-write failure (simulated by patching the file system) leaves the + previous file intact and no half-written replacement. + - The in-process L1 cache is unaffected by disk failure. +""" + +import json +import multiprocessing + +from ninetoothed._cache import Cache + + +def test_successful_put_leaves_no_tmp_file(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", {"v": 1}) + + # Exactly one .json file, no .tmp residue + files = list(tmp_path.glob("*")) + assert len(files) == 1 + assert files[0].suffix == ".json" + assert not list(tmp_path.glob("*.tmp")) + + +def test_put_overwrites_existing_file_atomically(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "first") + c.put("k", "second") + + files = list(tmp_path.glob("*.json")) + assert len(files) == 1 + assert json.loads(files[0].read_text()) == "second" + + +def test_disk_failure_preserves_l1(tmp_path): + """If the serializer raises, L1 must still have the value.""" + + def bad_serializer(_): + raise TypeError("nope") + + c = Cache( + cache_dir=tmp_path, + suffix=".json", + serializer=bad_serializer, + max_memory=4, + ) + c.put("k", "v") + # L1 was updated before the (failing) disk write, so it's still in memory + assert c.get("k") == "v" + + +def test_disk_failure_leaves_no_tmp_residue(tmp_path): + def bad_serializer(_): + raise TypeError("nope") + + c = Cache( + cache_dir=tmp_path, + suffix=".json", + serializer=bad_serializer, + max_memory=4, + ) + c.put("k", "v") + assert not list(tmp_path.glob("*.tmp")) + assert not list(tmp_path.glob("*.json")) + + +def test_oserror_during_rename_keeps_old_file(tmp_path): + """Simulate a rename() failure: original file must survive.""" + # First, write a real value through the cache to establish a real file. + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "original") + + path = c._path_for("k") + assert path.read_text() # file exists + + # Now patch os.replace to fail + import ninetoothed._cache as cache_mod + + original_replace = cache_mod.os.replace + + def fail_replace(src, dst): + raise OSError("simulated rename failure") + + cache_mod.os.replace = fail_replace + try: + c.put("k", "new_value") + finally: + cache_mod.os.replace = original_replace + + # The on-disk file should still contain the ORIGINAL value + assert json.loads(path.read_text()) == "original" + # And no .tmp residue + assert not list(tmp_path.glob("*.tmp")) + + +def test_no_tmp_residue_under_concurrent_writes(tmp_path): + """Many threads writing distinct keys should not leave .tmp files lying around.""" + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + + def writer(i): + for j in range(20): + c.put(f"k_{i}_{j}", j) + + import threading + + threads = [threading.Thread(target=writer, args=(i,)) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every key persisted; no .tmp residue + assert not list(tmp_path.glob("*.tmp")) + json_files = list(tmp_path.glob("*.json")) + assert len(json_files) > 0 + + # A fresh instance can read them all + c2 = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + for i in range(4): + for j in range(20): + assert c2.get(f"k_{i}_{j}") == j + + +def _process_writer(args): + cache_dir, value = args + c = Cache(cache_dir=cache_dir, suffix=".json", max_memory=1) + for _ in range(50): + c.put("shared", {"v": value}) + + +def test_process_safe_writes_same_key(tmp_path): + ctx = multiprocessing.get_context("spawn") + values = tuple(range(4)) + with ctx.Pool(processes=len(values)) as pool: + pool.map(_process_writer, [(tmp_path, value) for value in values]) + + assert not list(tmp_path.glob("*.tmp")) + + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=1) + stored = c.get("shared") + assert stored in [{"v": value} for value in values] diff --git a/tests/test_auto_tuner.py b/tests/test_auto_tuner.py index 69bcb21..7d026f2 100644 --- a/tests/test_auto_tuner.py +++ b/tests/test_auto_tuner.py @@ -1,22 +1,33 @@ import time import pytest +import torch from ninetoothed.auto_tuner import AutoTuner from tests.utils import get_available_devices +@pytest.fixture +def auto_tuner_factory(tmp_path, monkeypatch): + """Factory: each call yields a fresh AutoTuner backed by tmp_path. + + Avoids polluting ~/.ninetoothed across pytest runs / parametrize rows. + """ + import ninetoothed.auto_tuner as at_mod + + monkeypatch.setattr(at_mod, "CACHE_DIR", tmp_path / ".ninetoothed") + (tmp_path / ".ninetoothed").mkdir() + yield lambda: AutoTuner((_foo, _bar), (_foo.__name__, _bar.__name__)) + + @pytest.mark.parametrize("_", get_available_devices()) @pytest.mark.parametrize("kwargs", ({"a": 2, "b": 4}, {"a": 2, "b": 4, "c": 6, "d": 8})) @pytest.mark.parametrize("args", ((1,), (1, 3, 5))) -def test_auto_tuner(args, kwargs, _): - auto_tuner = AutoTuner((_foo, _bar), (_foo.__name__, _bar.__name__)) +def test_auto_tuner(args, kwargs, _, auto_tuner_factory): + auto_tuner = auto_tuner_factory() - assert not auto_tuner._get_func_cache_path(_foo).exists() - - assert not auto_tuner._get_func_cache_path(_bar).exists() - - assert not auto_tuner._cache_path.exists() + # Initial state: timings dict is empty (fresh cache). + assert auto_tuner._timings == {key: {} for key in auto_tuner._keys} first_time_start_time = time.perf_counter() @@ -26,11 +37,12 @@ def test_auto_tuner(args, kwargs, _): first_time_elapsed_time = first_time_end_time - first_time_start_time - assert auto_tuner._get_func_cache_path(_foo).exists() - - assert auto_tuner._get_func_cache_path(_bar).exists() + # After benchmarking, timings should be populated for every func + arg. + arg_key = auto_tuner._make_arg_key(args, kwargs) - assert auto_tuner._cache_path.exists() + for func_key in auto_tuner._keys: + assert arg_key in auto_tuner._timings[func_key] + assert arg_key in auto_tuner._best_func second_time_start_time = time.perf_counter() @@ -40,15 +52,10 @@ def test_auto_tuner(args, kwargs, _): second_time_elapsed_time = second_time_end_time - second_time_start_time + # Cached second call must be substantially faster. assert second_time_elapsed_time < first_time_elapsed_time - auto_tuner._get_func_cache_path(_foo).unlink() - - auto_tuner._get_func_cache_path(_bar).unlink() - - auto_tuner._cache_path.unlink() - - best_func = auto_tuner._best_func[auto_tuner._make_arg_key(args, kwargs)] + best_func = auto_tuner._best_func[arg_key] if _foo_delay(*args, **kwargs) < _bar_delay(*args, **kwargs): assert best_func is _foo @@ -56,6 +63,22 @@ def test_auto_tuner(args, kwargs, _): assert best_func is _bar +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="auto_tuner uses triton.testing.do_bench which requires a CUDA driver", +) +def test_auto_tuner_persists_across_instances(auto_tuner_factory): + """Re-instantiation should load timings from disk, skipping re-benchmark.""" + tuner1 = auto_tuner_factory() + tuner1(1, 2, 3, a=4) + + tuner2 = auto_tuner_factory() + + # Both instances see the same persisted timings loaded from disk. + assert tuner2._timings == tuner1._timings + assert tuner2._timings # not empty + + def _foo_delay(*args, **kwargs): return 0.001 * (2 * len(args) + len(kwargs)) diff --git a/tests/test_cache.py b/tests/test_cache.py new file mode 100644 index 0000000..297f82a --- /dev/null +++ b/tests/test_cache.py @@ -0,0 +1,382 @@ +"""Unit tests for ninetoothed._cache.Cache. + +Covers the public surface of the Cache class: + - in-memory mode (cache_dir=None) + - disk-backed mode (cache_dir=tmp_path) + - FIFO eviction at max_memory limit + - L2 -> L1 promotion on hit + - clear_memory() preserves L2 + - thread-safety (concurrent put/get from multiple threads) + - contains() reflects L1 + L2 + - integration: make() cache key is sensitive to non-Tensor + elements in the tensors tuple (e.g. ceil_mode), even when the + output shape happens to be the same for both values. +""" + +import threading + +import pytest +import torch +import torch.nn.functional as F + +from ninetoothed._cache import Cache +from ninetoothed.make import _HANDLE_CACHE +from tests.test_max_pool2d import max_pool2d +from tests.utils import get_available_devices + +# ---------- in-memory mode ---------- + + +def test_memory_only_put_and_get(): + c = Cache(max_memory=16) + c.put("k1", {"v": 1}) + assert c.get("k1") == {"v": 1} + + +def test_memory_only_get_missing_returns_default(): + c = Cache(max_memory=16) + assert c.get("missing") is None + assert c.get("missing", default="fallback") == "fallback" + + +def test_memory_only_is_memory_only_property(): + c = Cache(max_memory=16) + assert c.is_memory_only is True + assert c.cache_dir is None + + +def test_memory_only_does_not_persist_across_instances(): + c1 = Cache(max_memory=16) + c1.put("k", "v") + c2 = Cache(max_memory=16) + assert c2.get("k") is None # not shared across instances + + +# ---------- disk-backed mode ---------- + + +def test_disk_backed_persists(tmp_path): + disk = tmp_path / "cache" + c1 = Cache(cache_dir=disk, suffix=".json", max_memory=16) + c1.put("k1", [1, 2, 3]) + assert c1.get("k1") == [1, 2, 3] + + c2 = Cache(cache_dir=disk, suffix=".json", max_memory=16) + assert c2.get("k1") == [1, 2, 3] # reloaded from disk + + +def test_disk_backed_creates_directory(tmp_path): + disk = tmp_path / "deep" / "nested" / "cache" + Cache(cache_dir=disk, suffix=".json", max_memory=4) + assert disk.exists() + + +def test_disk_backed_is_not_memory_only(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + assert c.is_memory_only is False + assert c.cache_dir == tmp_path + + +def test_disk_backed_default_serializer_is_json(tmp_path): + import json + + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", {"a": 1}) + files = list(tmp_path.glob("*.json")) + assert len(files) == 1 + assert json.loads(files[0].read_text()) == {"a": 1} + + +def test_disk_backed_custom_serializer(tmp_path): + """identity serializer for plain strings; file is the string verbatim.""" + c = Cache( + cache_dir=tmp_path, + suffix=".txt", + serializer=lambda v: v, + deserializer=lambda s: s, + max_memory=4, + ) + c.put("k", "hello world") + files = list(tmp_path.glob("*.txt")) + assert len(files) == 1 + assert files[0].read_text() == "hello world" + assert c.get("k") == "hello world" + + +# ---------- L1 + L2 promotion ---------- + + +def test_l2_hit_promotes_to_l1(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "v1") + c.clear_memory() + assert c.memory_size == 0 + # get() should pull from disk and promote to L1 + assert c.get("k") == "v1" + assert c.memory_size == 1 + + +def test_clear_memory_preserves_disk(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "v1") + c.clear_memory() + # New instance can still read the value from disk + c2 = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + assert c2.get("k") == "v1" + + +def test_contains_reflects_l1_and_l2(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + assert c.contains("k") is False + c.put("k", "v") + assert c.contains("k") is True + c.clear_memory() + assert c.contains("k") is True # still in L2 + + +# ---------- FIFO eviction ---------- + + +def test_fifo_eviction_at_max_memory(): + c = Cache(max_memory=3) + c.put("a", 1) + c.put("b", 2) + c.put("c", 3) + assert c.memory_size == 3 + # 4th insert should evict "a" (FIFO) + c.put("d", 4) + assert c.memory_size == 3 + assert c.get("a") is None + assert c.get("b") == 2 + assert c.get("c") == 3 + assert c.get("d") == 4 + + +def test_updating_existing_key_does_not_evict_other_entries(): + c = Cache(max_memory=2) + c.put("a", 1) + c.put("b", 2) + + c.put("b", 3) + + assert c.memory_size == 2 + assert c.get("a") == 1 + assert c.get("b") == 3 + + +def test_fifo_eviction_with_disk(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=2) + c.put("a", 1) + c.put("b", 2) + c.put("c", 3) + # "a" was evicted from L1 + assert c.get("a") == 1 # should be promoted from L2 again + # After this promotion, L1 holds {b, a} (insertion order); "c" still in L1 + assert c.memory_size == 2 + + +# ---------- delete ---------- + + +def test_delete_removes_from_l1_and_l2(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "v") + assert c.contains("k") is True + c.delete("k") + assert c.contains("k") is False + assert c.get("k") is None + + +def test_delete_missing_key_is_noop(): + c = Cache(max_memory=4) + c.delete("never_existed") # should not raise + + +# ---------- disk corruption / bad data ---------- + + +def test_corrupt_disk_file_returns_default(tmp_path): + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + c.put("k", "v") + # Corrupt the file on disk + path = c._path_for("k") + path.write_text("this is not valid json {{{") + c.clear_memory() + # get() should swallow the JSONDecodeError and return default + assert c.get("k") is None + assert c.get("k", default="safe") == "safe" + + +# ---------- thread safety ---------- + + +def test_concurrent_put_get_no_race(tmp_path): + """Many threads putting + getting on disjoint keys must not corrupt state.""" + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=1024) + + errors = [] + + def worker(i): + try: + for j in range(50): + c.put(f"k_{i}_{j}", j) + v = c.get(f"k_{i}_{j}") + assert v == j + except Exception as e: # noqa: BLE001 + errors.append(e) + + threads = [threading.Thread(target=worker, args=(i,)) for i in range(8)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Thread errors: {errors}" + + +def test_concurrent_eviction_under_contention(): + """Stress eviction: many threads force frequent FIFO eviction.""" + c = Cache(max_memory=4) + c.put("seed", 0) + + errors = [] + + def worker(): + try: + for j in range(200): + c.put(f"k_{j}", j) + _ = c.get(f"k_{j}") + _ = c.memory_size + except Exception as e: # noqa: BLE001 + errors.append(e) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Thread errors: {errors}" + + +def test_concurrent_disk_writes_no_partial_files(tmp_path): + """After concurrent writes + force-flush, every persisted key must be readable.""" + c = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + + def writer(prefix, n): + for j in range(n): + c.put(f"{prefix}_{j}", {"v": j}) + # Force L1 eviction by over-filling + c.clear_memory() + c.get(f"{prefix}_{j}") # re-promote to L1, then we don't really need it + + threads = [threading.Thread(target=writer, args=(f"p{i}", 30)) for i in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Every key we wrote must round-trip from disk via a fresh instance + c2 = Cache(cache_dir=tmp_path, suffix=".json", max_memory=4) + for i in range(4): + for j in range(30): + assert c2.get(f"p{i}_{j}") == {"v": j} + + +# ---------- introspection ---------- + + +def test_memory_size_reports_l1_length(): + c = Cache(max_memory=16) + assert c.memory_size == 0 + c.put("a", 1) + assert c.memory_size == 1 + c.put("b", 2) + assert c.memory_size == 2 + + +# ---------- integration: make() cache key sensitivity ---------- +# These tests verify that _HANDLE_CACHE in ninetoothed.make produces +# DISTINCT cache keys for arrangements that differ only in non-Tensor +# elements of the tensors tuple (e.g. ceil_mode). This guards against +# regressions where the cache would mistake two semantically-different +# kernels for the same cache entry. +# +# The two interesting cases: +# (A) the differing argument produces a different output shape -- a +# shape-naive cache would still correctly miss here, so this is +# just a sanity check. +# (B) the differing argument produces an IDENTICAL output shape -- a +# shape-naive cache would mistakenly HIT and return the wrong +# kernel. This is the key regression test. +# (C) verify the cache HIT path: a second call with the same arguments +# must reuse the cached entry. + + +@pytest.fixture(autouse=True) +def _clear_handle_cache(): + """Each test starts (and ends) with an empty _HANDLE_CACHE L1.""" + _HANDLE_CACHE.clear_memory() + yield + _HANDLE_CACHE.clear_memory() + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_make_cache_distinguishes_ceil_mode_different_shapes(device): + """Sanity case: h=64, r=3 gives DIFFERENT output shapes for the two + ceil_mode values (False -> (21,21), True -> (22,22)). Cache must MISS + on the second call. (Even a shape-naive cache would miss here.)""" + torch.manual_seed(0) + x = torch.randn(32, 3, 64, 64, dtype=torch.float16, device=device) + + out_false = max_pool2d(x, (3, 3), ceil_mode=False) + out_true = max_pool2d(x, (3, 3), ceil_mode=True) + + assert out_false.shape == (32, 3, 21, 21) + assert out_true.shape == (32, 3, 22, 22) + assert torch.allclose(out_false, F.max_pool2d(x, (3, 3), ceil_mode=False)) + assert torch.allclose(out_true, F.max_pool2d(x, (3, 3), ceil_mode=True)) + # Two distinct ceil_mode values -> two L1 entries + assert _HANDLE_CACHE.memory_size == 2 + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_make_cache_distinguishes_ceil_mode_same_shape(device): + """Key regression test: h=63, r=3 gives IDENTICAL output shape (21,21) + for both ceil_mode values. A shape-naive cache would mistakenly HIT + on the second call and return the wrong kernel (the False kernel, + which uses floor_mode=True, when True was requested, which uses + floor_mode=False). + + The cache MUST still MISS: tensors tuple contains the raw bool + ceil_mode, whose repr() distinguishes True from False, and the + cache key is content-sensitive to that repr().""" + torch.manual_seed(0) + x = torch.randn(32, 3, 63, 63, dtype=torch.float16, device=device) + + out_false = max_pool2d(x, (3, 3), ceil_mode=False) + out_true = max_pool2d(x, (3, 3), ceil_mode=True) + + # Output shapes are identical -- shape alone cannot disambiguate + assert out_false.shape == out_true.shape == (32, 3, 21, 21) + # But values DO differ between the two kernels (different floor_mode + # in the arrangement), and each matches its reference + assert torch.allclose(out_false, F.max_pool2d(x, (3, 3), ceil_mode=False)) + assert torch.allclose(out_true, F.max_pool2d(x, (3, 3), ceil_mode=True)) + # Cache must hold 2 distinct entries (one per ceil_mode value) + assert _HANDLE_CACHE.memory_size == 2 + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_make_cache_reuses_unchanged_ceil_mode(device): + """Verify cache HIT path: a second call with the same ceil_mode + must reuse the cached entry (L1 size unchanged).""" + torch.manual_seed(0) + x = torch.randn(32, 3, 64, 64, dtype=torch.float16, device=device) + + out1 = max_pool2d(x, (3, 3), ceil_mode=False) + assert _HANDLE_CACHE.memory_size == 1 # first call: miss + put + + out2 = max_pool2d(x, (3, 3), ceil_mode=False) + assert _HANDLE_CACHE.memory_size == 1 # second call: hit, no new entry + assert torch.allclose(out1, out2) diff --git a/tests/test_make_cache_key.py b/tests/test_make_cache_key.py new file mode 100644 index 0000000..6815160 --- /dev/null +++ b/tests/test_make_cache_key.py @@ -0,0 +1,333 @@ +"""Unit tests for the key-derivation helpers in ninetoothed._cache. + +Covers: + - hash_function_source: source sensitivity, functools.partial, fallback. + - hash_tensor_signature: structural identity, NOT instance-bound names. + - hash_value: stability across calls. +""" + +import functools + +from ninetoothed._cache import ( + hash_function_source, + hash_tensor_signature, + hash_value, +) +from ninetoothed.tensor import Tensor + +# ---------- hash_function_source ---------- + + +def test_returns_string(): + def f(): + pass + + h = hash_function_source(f) + assert isinstance(h, str) + assert len(h) > 0 + + +def test_returns_src_prefix_for_normal_function(): + def f(): + return 1 + + h = hash_function_source(f) + assert h.startswith("src:") + + +def test_same_function_same_hash(): + def f(x): + return x + 1 + + assert hash_function_source(f) == hash_function_source(f) + + +def test_modified_function_different_hash(): + def f(x): + return x + 1 + + h1 = hash_function_source(f) + + def f(x): # noqa: F811 + return x + 2 + + h2 = hash_function_source(f) + assert h1 != h2 + + +def test_different_functions_different_hash(): + def f(): + return 1 + + def g(): + return 2 + + assert hash_function_source(f) != hash_function_source(g) + + +def test_functools_partial_unwrapped(): + """partial(jagged_dim=1) must differ from partial(jagged_dim=2).""" + + def base(x): + return x + + p1 = functools.partial(base, jagged_dim=1) + p2 = functools.partial(base, jagged_dim=2) + + assert hash_function_source(p1) != hash_function_source(p2) + + +def test_functools_partial_nested(): + """Nested partials (e.g. partial(partial(base, a=1), b=2)) get unwrapped recursively.""" + + def base(x): + return x + + p_ab = functools.partial(functools.partial(base, a=1), b=2) + p_ba = functools.partial(functools.partial(base, b=2), a=1) + + # Both end up with same final bound state, so hashes should match + assert hash_function_source(p_ab) == hash_function_source(p_ba) + + +def test_closure_values_affect_function_hash(): + def make_arrangement(scale): + def arrangement(input, output): + tile = input.tile((1, scale)) + return tile, output + + return arrangement + + assert hash_function_source(make_arrangement(2)) != hash_function_source( + make_arrangement(4) + ) + + +def test_global_helper_source_affects_function_hash(monkeypatch): + def helper(x): + return x + 1 + + def application(x): + return helper(x) + + original = hash_function_source(application) + + def helper(x): # noqa: F811 + return x + 2 + + monkeypatch.setitem(application.__globals__, "helper", helper) + + assert hash_function_source(application) != original + + +def test_functools_partial_with_args(): + """partial(f, 1, 2) vs partial(f, 3, 4) differ because of bound args.""" + + def base(x, y): + return x + y + + p12 = functools.partial(base, 1, 2) + p34 = functools.partial(base, 3, 4) + + assert hash_function_source(p12) != hash_function_source(p34) + + +def test_fallback_id_prefix_when_source_unavailable(): + """Lambdas defined in REPL/test scope may fail getsource; should still hash.""" + # Create a function whose source cannot be retrieved: a built-in + h = hash_function_source(len) + assert h.startswith("id:") + # Same callable -> same hash + assert hash_function_source(len) == h + + +def test_id_fallback_stable_across_calls(): + h1 = hash_function_source(len) + h2 = hash_function_source(len) + assert h1 == h2 + + +def test_fallback_distinguishes_functools_partial_kwargs(): + """When inspect.getsource fails (REPL, exec, Jupyter), functools.partial + with different bound kwargs must still hash differently. + + Regression test for the silent correctness bug where the fallback path + returned `id:module.qualname@id(func)` -- dropping partial_args entirely, + so `partial(f, ceil_mode=True)` and `partial(f, ceil_mode=False)` would + collide and return the wrong cached handle. + """ + import ninetoothed._cache as cache_mod + + def arrangement(input, output, ceil_mode=False): + return input, output + + p_true = functools.partial(arrangement, ceil_mode=True) + p_false = functools.partial(arrangement, ceil_mode=False) + + # Force the fallback path by patching inspect.getsource to raise. + original_getsource = cache_mod.inspect.getsource + + def _raise_oserror(*args, **kwargs): + raise OSError("simulated: source unavailable") + + cache_mod.inspect.getsource = _raise_oserror + try: + h_true = hash_function_source(p_true) + h_false = hash_function_source(p_false) + finally: + cache_mod.inspect.getsource = original_getsource + + # Both should be in the fallback (`id:`) prefix + assert h_true.startswith("id:") + assert h_false.startswith("id:") + + # Different kwargs must produce different hashes + assert h_true != h_false + + +def test_fallback_distinguishes_functools_partial_args(): + """Same as above, but for positional args (e.g. `partial(f, 1, 2)` vs + `partial(f, 3, 4)`).""" + import ninetoothed._cache as cache_mod + + def base(x, y): + return x + y + + p12 = functools.partial(base, 1, 2) + p34 = functools.partial(base, 3, 4) + + original_getsource = cache_mod.inspect.getsource + + def _raise_oserror(*args, **kwargs): + raise OSError("simulated: source unavailable") + + cache_mod.inspect.getsource = _raise_oserror + try: + h12 = hash_function_source(p12) + h34 = hash_function_source(p34) + finally: + cache_mod.inspect.getsource = original_getsource + + assert h12.startswith("id:") + assert h34.startswith("id:") + assert h12 != h34 + + +# ---------- hash_tensor_signature ---------- + + +class FakeTensor: + """Mimics the ninetoothed.Tensor surface used by hash_tensor_signature.""" + + def __init__( + self, + ndim, + jagged_dim=None, + other=0, + name="t", + shape=None, + dtype=None, + constexpr=False, + value=None, + ): + self.ndim = ndim + self.jagged_dim = jagged_dim + self.other = other + self.name = name + self.shape = tuple(None for _ in range(ndim)) if shape is None else shape + self.dtype = dtype + self.constexpr = constexpr + self.value = value + + +def test_returns_tuple(): + t = FakeTensor(ndim=2) + sig = hash_tensor_signature(t) + assert isinstance(sig, tuple) + + +def test_same_structure_same_signature(): + """Two tensors with the same ndim/jagged_dim/other must hash equal + even if they were constructed with different instance-bound names.""" + t1 = FakeTensor(ndim=2, jagged_dim=None, other=0, name="x_0") + t2 = FakeTensor(ndim=2, jagged_dim=None, other=0, name="x_42") + assert hash_tensor_signature(t1) == hash_tensor_signature(t2) + + +def test_different_ndim_different_signature(): + t1 = FakeTensor(ndim=2) + t2 = FakeTensor(ndim=3) + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_different_jagged_dim_different_signature(): + t1 = FakeTensor(ndim=2, jagged_dim=0) + t2 = FakeTensor(ndim=2, jagged_dim=1) + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_different_other_different_signature(): + t1 = FakeTensor(ndim=2, other=0) + t2 = FakeTensor(ndim=2, other=1) + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_different_static_shape_different_signature(): + t1 = FakeTensor(ndim=2, shape=(None, 128)) + t2 = FakeTensor(ndim=2, shape=(None, 256)) + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_different_dtype_different_signature(): + t1 = FakeTensor(ndim=2, dtype="float16") + t2 = FakeTensor(ndim=2, dtype="float32") + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_different_constexpr_value_different_signature(): + t1 = Tensor(0, constexpr=True, value=16) + t2 = Tensor(0, constexpr=True, value=32) + assert hash_tensor_signature(t1) != hash_tensor_signature(t2) + + +def test_name_attribute_does_not_affect_signature(): + """Critical: Tensor.name is instance-counter-bound and must NOT influence + the cache key, otherwise every fresh Tensor (even structurally identical) + would miss the cache.""" + t_a = FakeTensor(ndim=2, name="input_0") + t_b = FakeTensor(ndim=2, name="input_99999") + assert hash_tensor_signature(t_a) == hash_tensor_signature(t_b) + + +# ---------- hash_value ---------- + + +def test_hash_value_stable(): + assert hash_value(42) == hash_value(42) + assert hash_value("abc") == hash_value("abc") + assert hash_value([1, 2, 3]) == hash_value([1, 2, 3]) + + +def test_hash_value_different_inputs_different_hash(): + assert hash_value(42) != hash_value(43) + assert hash_value("abc") != hash_value("abd") + + +def test_hash_value_returns_hex_string(): + h = hash_value(42) + assert isinstance(h, str) + # SHA256 hex digest is 64 chars + assert len(h) == 64 + int(h, 16) # must be valid hex + + +def test_hash_value_handles_arbitrary_python_objects(): + """repr-based, so any object with a sensible repr works.""" + + class Obj: + def __repr__(self): + return "Obj()" + + a = Obj() + b = Obj() + assert hash_value(a) == hash_value(b) diff --git a/tests/test_max_pool2d.py b/tests/test_max_pool2d.py index 3b6c372..a2cf57d 100644 --- a/tests/test_max_pool2d.py +++ b/tests/test_max_pool2d.py @@ -1,5 +1,3 @@ -import functools - import pytest import torch import torch.nn.functional as F @@ -50,16 +48,16 @@ def _div(x, y, ceil_mode=False): output = torch.empty(n, c, p, q, dtype=input.dtype, device=input.device) - max_pool2d_kernels = { - ceil_mode: ninetoothed.make( - functools.partial(arrangement, ceil_mode=ceil_mode), - application, - (Tensor(4, other=float("-inf")), Tensor(4)), - ) - for ceil_mode in (True, False) - } - - max_pool2d_kernels[ceil_mode](input, output, WINDOW_HEIGHT=r, WINDOW_WIDTH=s) + # Unified Cache (per-process L1 in ninetoothed.make) correctly + # distinguishes ceil_mode via the tensors tuple's repr() hash on + # the trailing `ceil_mode` element (matching test_pad's pattern of + # including non-Tensor arrangement kwargs in tensors). + kernel = ninetoothed.make( + arrangement, + application, + (Tensor(4, other=float("-inf")), Tensor(4), ceil_mode), + ) + kernel(input, output, WINDOW_HEIGHT=r, WINDOW_WIDTH=s) return output diff --git a/tests/test_pad.py b/tests/test_pad.py index 3219184..aec44a4 100644 --- a/tests/test_pad.py +++ b/tests/test_pad.py @@ -25,15 +25,12 @@ def pad(input, pad, mode="constant", value=None): output = torch.full(output_shape, value, dtype=input.dtype, device=input.device) ndim = input.ndim - kernel_config = (ndim, input_slices, output_slices) - kernel_key = str(kernel_config) + tensors = (Tensor(ndim), Tensor(ndim), input_slices, output_slices) - if kernel_key not in _kernel_cache: - tensors = (Tensor(ndim), Tensor(ndim), input_slices, output_slices) - - _kernel_cache[kernel_key] = make(arrangement, application, tensors) - - _kernel_cache[kernel_key](input, output) + # Unified Cache (per-process L1 in ninetoothed.make) handles the + # `input_slices` / `output_slices` slice objects via repr() hashing. + kernel = make(arrangement, application, tensors) + kernel(input, output) return output @@ -63,9 +60,6 @@ def test_pad(shape, pad_, mode, value, dtype, device, atol): assert torch.allclose(output, expected, atol=atol) -_kernel_cache = {} - - def _analyze_pad_config(input, pad, mode): assert mode == "constant", 'Only `"constant"` padding mode is supported.'