diff --git a/charge3net_ft/data.py b/charge3net_ft/data.py index 34ffa8d..544097f 100644 --- a/charge3net_ft/data.py +++ b/charge3net_ft/data.py @@ -10,6 +10,7 @@ opened tables per chunk file so each file is read from disk only once per worker. """ +import collections import json import sys from functools import partial @@ -55,10 +56,20 @@ # --------------------------------------------------------------------------- _SYMBOL_TO_Z = {s: z for z, s in enumerate(ase.data.chemical_symbols)} -# Process-local table cache: keyed by file index, populated on first access. -# Each DataLoader worker process has its own cache, so each chunk file is read -# from disk at most once per worker instead of once per __getitem__ call. -_TABLE_CACHE: dict = {} +# Process-local LRU table cache: keyed by file index, populated on first access. +# Each DataLoader worker has its own cache (workers fork the parent), so each +# chunk file is read from disk at most once per worker per cache cycle. +# +# Bounded LRU because the previous unbounded version OOM-killed jobs 4971293 +# and 4971343 at MaxRSS=35 GB/rank. Per-chunk decompressed pyarrow tables +# weigh ~2 GB (the compressed_charge_density JSON strings inflate 6x from +# disk). With 8 workers x 4 DDP ranks = 32 workers, an unbounded cache grew +# to ~140 GB total in 6 h. +# +# Cap of 5 chunks per worker keeps each worker's cache around 10 GB worst +# case, well under any per-rank memory budget. OrderedDict gives O(1) LRU. 
+_TABLE_CACHE_MAX_CHUNKS = 5 +_TABLE_CACHE: "collections.OrderedDict[int, object]" = collections.OrderedDict() def _parse_grid_json(json_str: str) -> np.ndarray: @@ -131,7 +142,9 @@ def _build_parquet_index(parquet_dir: Path) -> tuple: index.append((fi, ri)) n_valid = len(index) - print(f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files") + print( + f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files" + ) return file_paths, index @@ -186,11 +199,21 @@ def _read_row(self, idx: int) -> dict: """ Read a single row from disk via its index entry. - Uses a process-local cache (_TABLE_CACHE) so each chunk file is - loaded from disk only once per worker, not on every __getitem__ call. + Uses a process-local LRU cache (_TABLE_CACHE) so each chunk file is + loaded from disk at most once per worker per cache cycle. Cache is + capped at _TABLE_CACHE_MAX_CHUNKS entries; on a miss past capacity + the least-recently-used chunk is evicted. Re-access of a present + entry promotes it to most-recent so the running shuffled-access + pattern from RandomSampler doesn't constantly thrash. """ fi, ri = self._index[idx] - if fi not in _TABLE_CACHE: + if fi in _TABLE_CACHE: + # Hit: bump to most-recent and return. + _TABLE_CACHE.move_to_end(fi) + else: + # Miss: evict LRU if at capacity, then read. + if len(_TABLE_CACHE) >= _TABLE_CACHE_MAX_CHUNKS: + _TABLE_CACHE.popitem(last=False) _TABLE_CACHE[fi] = pq.read_table(self._file_paths[fi], columns=_COLUMNS) table = _TABLE_CACHE[fi] row = {} @@ -230,6 +253,7 @@ def build_dataloaders( num_workers: int = 4, seed: int = 42, pin_memory: bool = False, + distributed: bool = False, ) -> tuple: """ Build train, validation, and test DataLoaders. @@ -298,10 +322,27 @@ def build_dataloaders( collate_fn = partial(collate_list_of_dicts, pin_memory=pin_memory) + # DDP path: shard the training set across ranks via DistributedSampler. 
+ # Val/test stay non-distributed (each rank evaluates the whole set; only + # rank 0 reports). This wastes V+T compute but keeps eval simple and + # rank-agnostic. The data is tiny (5%+5% of 65k) so it's fine. + train_sampler = None + if distributed: + from torch.utils.data.distributed import DistributedSampler + + train_sampler = DistributedSampler( + train_subset, + shuffle=True, + seed=seed, + drop_last=True, + ) + train_loader = DataLoader( train_subset, batch_size=batch_size, - shuffle=True, + # shuffle and sampler are mutually exclusive in DataLoader. + shuffle=(train_sampler is None), + sampler=train_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, diff --git a/charge3net_ft/train.py b/charge3net_ft/train.py index c30b277..5004f3f 100644 --- a/charge3net_ft/train.py +++ b/charge3net_ft/train.py @@ -47,9 +47,49 @@ from .model import ChargE3NetWrapper # noqa: E402 +# --------------------------------------------------------------------------- +# Distributed training helpers +# --------------------------------------------------------------------------- +def _is_ddp() -> bool: + """True if SLURM/torchrun has set up multi-process training.""" + return int(os.environ.get("WORLD_SIZE", "1")) > 1 + + +def _setup_ddp() -> tuple[int, int, int]: + """Initialize the process group and return (rank, local_rank, world_size). + + No-op (returns 0, 0, 1) if we're not in a distributed environment. + + The submit script is expected to export the standard torch env vars from + SLURM: + WORLD_SIZE = $SLURM_NTASKS + RANK = $SLURM_PROCID + LOCAL_RANK = $SLURM_LOCALID + MASTER_ADDR = $(scontrol show hostname $SLURM_NODELIST | head -1) + MASTER_PORT = some unused port (e.g. 29500) + """ + if not _is_ddp(): + return 0, 0, 1 + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # nccl works on AMD ROCm because PyTorch routes it through RCCL. 
+ torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return rank, local_rank, world_size + + +def _is_main(rank: int) -> bool: + """True on rank 0; used to gate prints, wandb, and checkpoint saves.""" + return rank == 0 + + def _probe_mask(targets: torch.Tensor, num_probes: torch.Tensor) -> torch.Tensor: """Boolean mask [B, max_probes], True for real probe points (not padding).""" - return torch.arange(targets.shape[1], device=targets.device)[None] < num_probes[:, None] + return ( + torch.arange(targets.shape[1], device=targets.device)[None] + < num_probes[:, None] + ) def compute_nmape( @@ -108,8 +148,16 @@ def compute_nrmse( return (rmse / (mean_abs + 1e-10) * 100.0).mean() -def train_one_epoch(model, train_loader, optimizer, scheduler, device, global_step, - log_every=50, use_wandb=False): +def train_one_epoch( + model, + train_loader, + optimizer, + scheduler, + device, + global_step, + log_every=50, + use_wandb=False, +): """Run one training epoch, return (average loss, updated global_step).""" model.train() total_loss = 0.0 @@ -135,7 +183,7 @@ def train_one_epoch(model, train_loader, optimizer, scheduler, device, global_st if (i + 1) % log_every == 0: lr = optimizer.param_groups[0]["lr"] - print(f" step {i+1}: loss={loss.item():.6f} lr={lr:.2e}") + print(f" step {i + 1}: loss={loss.item():.6f} lr={lr:.2e}") if use_wandb: wandb.log({"train/loss_step": loss.item(), "lr": lr}, step=global_step) @@ -177,12 +225,22 @@ def validate(model, loader, device): } +def _unwrap(model): + """Return the underlying ChargE3NetWrapper regardless of DDP wrapping. + + DistributedDataParallel wraps the user model in a ``.module`` attribute; + state_dict() and load_state_dict() should always target the inner model + so checkpoints are interchangeable between single-GPU and DDP runs. 
+ """ + return model.module if hasattr(model, "module") else model + + def save_checkpoint(model, optimizer, scheduler, epoch, best_nmape, global_step, path): - """Save training checkpoint.""" + """Save training checkpoint (rank 0 should be the only caller in DDP).""" torch.save( { "epoch": epoch, - "model": model.model.state_dict(), + "model": _unwrap(model).model.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "best_nmape": best_nmape, @@ -195,7 +253,7 @@ def save_checkpoint(model, optimizer, scheduler, epoch, best_nmape, global_step, def load_checkpoint(path, model, optimizer, scheduler, device): """Load training checkpoint, return (start_epoch, best_nmape, global_step).""" ckpt = torch.load(path, map_location=device, weights_only=False) - model.model.load_state_dict(ckpt["model"]) + _unwrap(model).model.load_state_dict(ckpt["model"]) optimizer.load_state_dict(ckpt["optimizer"]) scheduler.load_state_dict(ckpt["scheduler"]) start_epoch = ckpt["epoch"] + 1 @@ -222,19 +280,35 @@ def main(): "Defaults to $LEMATRHO_DATA_DIR env var." 
), ) - parser.add_argument("--ckpt-path", type=str, default=None, help="Pre-trained checkpoint (.pt)") - parser.add_argument("--save-dir", type=str, default="./checkpoints", help="Save directory") + parser.add_argument( + "--ckpt-path", type=str, default=None, help="Pre-trained checkpoint (.pt)" + ) + parser.add_argument( + "--save-dir", type=str, default="./checkpoints", help="Save directory" + ) parser.add_argument("--cutoff", type=float, default=4.0, help="Neighbor cutoff (A)") - parser.add_argument("--train-probes", type=int, default=200, help="Probes per sample (train)") - parser.add_argument("--val-probes", type=int, default=1000, help="Probes per sample (val/test)") + parser.add_argument( + "--train-probes", type=int, default=200, help="Probes per sample (train)" + ) + parser.add_argument( + "--val-probes", type=int, default=1000, help="Probes per sample (val/test)" + ) parser.add_argument("--batch-size", type=int, default=4, help="Batch size") parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") parser.add_argument("--epochs", type=int, default=50, help="Number of epochs") - parser.add_argument("--val-frac", type=float, default=0.05, - help="Validation fraction. Do not change after first run.") - parser.add_argument("--test-frac", type=float, default=0.05, - help="Test fraction (held out, evaluated once at end). " - "Do not change after first run.") + parser.add_argument( + "--val-frac", + type=float, + default=0.05, + help="Validation fraction. Do not change after first run.", + ) + parser.add_argument( + "--test-frac", + type=float, + default=0.05, + help="Test fraction (held out, evaluated once at end). 
" + "Do not change after first run.", + ) parser.add_argument("--num-workers", type=int, default=4, help="DataLoader workers") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--log-every", type=int, default=50, help="Log every N steps") @@ -252,14 +326,22 @@ def main(): default=None, help="Force device (cpu, cuda, mps). Auto-detect if not set.", ) - parser.add_argument("--resume-from", type=str, default=None, - help="Path to training checkpoint (latest.pt) to resume from") + parser.add_argument( + "--resume-from", + type=str, + default=None, + help="Path to training checkpoint (latest.pt) to resume from", + ) parser.add_argument("--wandb-project", type=str, default="lemat-rho-charge3net") parser.add_argument("--wandb-entity", type=str, default="dtts") parser.add_argument("--no-wandb", action="store_true", help="Disable W&B logging") - parser.add_argument("--wandb-mode", type=str, default="online", - choices=["online", "offline", "disabled"], - help="W&B mode (use 'offline' on air-gapped clusters)") + parser.add_argument( + "--wandb-mode", + type=str, + default="online", + choices=["online", "offline", "disabled"], + help="W&B mode (use 'offline' on air-gapped clusters)", + ) args = parser.parse_args() if args.parquet_dir is None: @@ -274,28 +356,48 @@ def main(): if torch.cuda.is_available(): torch.cuda.manual_seed_all(args.seed) + # DDP setup (no-op when WORLD_SIZE=1). Must happen before device + # selection because each rank pins itself to its own GPU via local_rank. 
+ rank, local_rank, world_size = _setup_ddp() + is_main = _is_main(rank) + # Device if args.device: device = torch.device(args.device) + elif _is_ddp(): + device = torch.device(f"cuda:{local_rank}") elif torch.cuda.is_available(): device = torch.device("cuda") else: device = torch.device("cpu") - print(f"Using device: {device}") - - # W&B - use_wandb = not args.no_wandb and not args.smoke_test + if is_main: + print(f"Using device: {device}; world_size={world_size}") + + # W&B (rank 0 only). Soft-fail: if init times out (e.g. compute node + # can't reach api.wandb.ai through the cluster proxy), degrade to + # disabled mode and keep training. Used to be fatal — caused the + # 1h47m job 4969727 timeout-then-crash on Adastra. + use_wandb = (not args.no_wandb and not args.smoke_test) and is_main if use_wandb: - wandb.init( - project=args.wandb_project, - entity=args.wandb_entity, - config=vars(args), - settings=wandb.Settings(init_timeout=300), - mode=args.wandb_mode, - ) + try: + wandb.init( + project=args.wandb_project, + entity=args.wandb_entity, + config=vars(args), + settings=wandb.Settings(init_timeout=300), + mode=args.wandb_mode, + ) + except Exception as e: # noqa: BLE001 — really do want broad here + print( + f"WARNING: wandb.init failed ({type(e).__name__}: {e}); " + "continuing with wandb disabled. Training output is still " + "saved to checkpoints + stdout." 
+ ) + use_wandb = False # Data - print("Building dataloaders...") + if is_main: + print("Building dataloaders...") train_loader, val_loader, test_loader = build_dataloaders( parquet_dir=args.parquet_dir, cutoff=args.cutoff, @@ -306,26 +408,40 @@ def main(): test_frac=args.test_frac, num_workers=args.num_workers, seed=args.seed, + distributed=_is_ddp(), ) - print( - f"Train: {len(train_loader.dataset)} samples, " - f"Val: {len(val_loader.dataset)} samples, " - f"Test: {len(test_loader.dataset)} samples" - ) + if is_main: + print( + f"Train: {len(train_loader.dataset)} samples, " + f"Val: {len(val_loader.dataset)} samples, " + f"Test: {len(test_loader.dataset)} samples" + ) - # Model - print("Initializing ChargE3Net...") + # Model. Loaded on every rank (each gets its own copy of the weights); + # DDP will sync gradients across ranks at backward. + if is_main: + print("Initializing ChargE3Net...") model = ChargE3NetWrapper(ckpt_path=args.ckpt_path, cutoff=args.cutoff) model = model.to(device) - n_params = sum(p.numel() for p in model.parameters()) - print(f"Model parameters: {n_params:,}") + if _is_ddp(): + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) + n_params = sum( + p.numel() for p in (model.module if _is_ddp() else model).parameters() + ) + if is_main: + print(f"Model parameters: {n_params:,}") # Smoke test: just run one forward pass if args.smoke_test: print("\n--- Smoke test ---") model.eval() batch = next(iter(train_loader)) - batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()} + batch = { + k: v.to(device) if isinstance(v, torch.Tensor) else v + for k, v in batch.items() + } print(f"Batch keys: {list(batch.keys())}") for k, v in batch.items(): if isinstance(v, torch.Tensor): @@ -376,13 +492,15 @@ def main(): f"NMAPE={nmape.item():.2f}% RMSE={rmse.item():.4f} NRMSE={nrmse.item():.2f}%" ) if use_wandb: - wandb.log({ - "overfit/L1": loss.item(), - 
"overfit/NMAPE": nmape.item(), - "overfit/RMSE": rmse.item(), - "overfit/NRMSE": nrmse.item(), - "epoch": epoch, - }) + wandb.log( + { + "overfit/L1": loss.item(), + "overfit/NMAPE": nmape.item(), + "overfit/RMSE": rmse.item(), + "overfit/NRMSE": nrmse.item(), + "epoch": epoch, + } + ) print("\nOverfit test complete.") if use_wandb: @@ -405,53 +523,86 @@ def main(): if args.resume_from: start_epoch, best_nmape, global_step = load_checkpoint( - args.resume_from, model, optimizer, scheduler, device, + args.resume_from, + model, + optimizer, + scheduler, + device, ) - print(f"\nStarting training from epoch {start_epoch + 1} to {args.epochs}...") + if is_main: + print(f"\nStarting training from epoch {start_epoch + 1} to {args.epochs}...") for epoch in range(start_epoch, args.epochs): + # DDP requires set_epoch on the sampler each epoch for proper shuffling. + if _is_ddp() and hasattr(train_loader.sampler, "set_epoch"): + train_loader.sampler.set_epoch(epoch) t0 = time.time() train_loss, global_step = train_one_epoch( - model, train_loader, optimizer, scheduler, device, global_step, - log_every=args.log_every, use_wandb=use_wandb, + model, + train_loader, + optimizer, + scheduler, + device, + global_step, + log_every=args.log_every, + use_wandb=use_wandb, ) val = validate(model, val_loader, device) elapsed = time.time() - t0 - print( - f"Epoch {epoch+1}/{args.epochs} " - f"train_L1={train_loss:.6f} " - f"val_L1={val['L1']:.6f} " - f"val_NMAPE={val['NMAPE']:.2f}% " - f"val_RMSE={val['RMSE']:.4f} " - f"val_NRMSE={val['NRMSE']:.2f}% " - f"time={elapsed:.0f}s" - ) + if is_main: + print( + f"Epoch {epoch + 1}/{args.epochs} " + f"train_L1={train_loss:.6f} " + f"val_L1={val['L1']:.6f} " + f"val_NMAPE={val['NMAPE']:.2f}% " + f"val_RMSE={val['RMSE']:.4f} " + f"val_NRMSE={val['NRMSE']:.2f}% " + f"time={elapsed:.0f}s" + ) if use_wandb: - wandb.log({ - "train/L1": train_loss, - "val/L1": val["L1"], - "val/NMAPE": val["NMAPE"], - "val/RMSE": val["RMSE"], - "val/NRMSE": 
val["NRMSE"], - "epoch": epoch + 1, - }, step=global_step) - - # Save best checkpoint (selected on val NMAPE) - if val["NMAPE"] < best_nmape: + wandb.log( + { + "train/L1": train_loss, + "val/L1": val["L1"], + "val/NMAPE": val["NMAPE"], + "val/RMSE": val["RMSE"], + "val/NRMSE": val["NRMSE"], + "epoch": epoch + 1, + }, + step=global_step, + ) + + # Save best checkpoint (selected on val NMAPE). Only rank 0 writes. + if is_main and val["NMAPE"] < best_nmape: best_nmape = val["NMAPE"] save_checkpoint( - model, optimizer, scheduler, epoch, best_nmape, global_step, + model, + optimizer, + scheduler, + epoch, + best_nmape, + global_step, save_dir / "best.pt", ) print(f" -> New best val NMAPE: {best_nmape:.2f}%") - # Save latest checkpoint every epoch (for SLURM resumption) - save_checkpoint( - model, optimizer, scheduler, epoch, best_nmape, global_step, - save_dir / "latest.pt", - ) + # Save latest checkpoint every epoch (for SLURM resumption). + if is_main: + save_checkpoint( + model, + optimizer, + scheduler, + epoch, + best_nmape, + global_step, + save_dir / "latest.pt", + ) + + # Keep ranks in lockstep so a slow saver doesn't get lapped. + if _is_ddp(): + torch.distributed.barrier() # ----------------------------------------------------------------------- # Test set evaluation — run once at the end using the best checkpoint. @@ -471,17 +622,22 @@ def main(): f"RMSE={test['RMSE']:.4f} NRMSE={test['NRMSE']:.2f}%" ) if use_wandb: - wandb.log({ - "test/L1": test["L1"], - "test/NMAPE": test["NMAPE"], - "test/RMSE": test["RMSE"], - "test/NRMSE": test["NRMSE"], - }) - - print(f"\nTraining complete. Best val NMAPE: {best_nmape:.2f}%") - print(f"Checkpoints saved to {save_dir}") + wandb.log( + { + "test/L1": test["L1"], + "test/NMAPE": test["NMAPE"], + "test/RMSE": test["RMSE"], + "test/NRMSE": test["NRMSE"], + } + ) + + if is_main: + print(f"\nTraining complete. 
Best val NMAPE: {best_nmape:.2f}%") + print(f"Checkpoints saved to {save_dir}") if use_wandb: wandb.finish() + if _is_ddp(): + torch.distributed.destroy_process_group() if __name__ == "__main__": diff --git a/deepdft_ft/__init__.py b/deepdft_ft/__init__.py new file mode 100644 index 0000000..da6fdd4 --- /dev/null +++ b/deepdft_ft/__init__.py @@ -0,0 +1,5 @@ +"""DeepDFT (peterbjorgensen/DeepDFT) fine-tuning glue for LeMat-Rho. + +Mirrors ``charge3net_ft/`` in structure: the data loader reuses ``charge3net_ft``'s +parquet helpers and adapts the per-sample shape to DeepDFT's dict contract. +""" diff --git a/deepdft_ft/data.py b/deepdft_ft/data.py new file mode 100644 index 0000000..03acbb2 --- /dev/null +++ b/deepdft_ft/data.py @@ -0,0 +1,139 @@ +"""LeMat-Rho → DeepDFT data adapter. + +DeepDFT's ``runner.py`` expects a ``torch.utils.data.Dataset`` that yields +per-sample dicts of the form:: + + { + "density": np.ndarray (Nx, Ny, Nz), + "atoms": ase.Atoms, + "origin": np.ndarray (3,), + "grid_position": np.ndarray (Nx, Ny, Nz, 3), + "metadata": {"filename": str, ...}, + } + +That dict is fed into DeepDFT's ``CollateFuncRandomSample`` which samples +random probe points, builds the atom/probe graph via asap3, and pads the +batch. The only thing we provide is a path from a directory of LeMat-Rho +parquet chunks to that dict shape. + +The parquet schema, the index building, and the row → (atoms, density, origin) +conversion live in ``charge3net_ft.data`` and are reused verbatim. Keeping a +single source of truth for the input pipeline means a future Bader/extra-column +addition only needs one regression test. 
+""" + +from __future__ import annotations + +from pathlib import Path +from typing import Optional + +import numpy as np +import pyarrow.parquet as pq +from torch.utils.data import Dataset + +from charge3net_ft.data import ( + _COLUMNS, + _build_parquet_index, + _row_to_atoms_and_density, +) + + +# Per-worker cache, separate from charge3net_ft's so the two pipelines don't +# step on each other when running side by side in the same process. +_DEEPDFT_TABLE_CACHE: dict = {} + + +def _calculate_grid_pos(density: np.ndarray, origin: np.ndarray, cell) -> np.ndarray: + """Cartesian probe positions for an (Nx, Ny, Nz) density grid. + + Same formula DeepDFT uses internally (see DeepDFT/dataset.py:_calculate_grid_pos). + Kept here so we don't need DeepDFT importable at test time. + + Parameters + ---------- + density : np.ndarray of shape (Nx, Ny, Nz) + Used only for its shape. + origin : np.ndarray of shape (3,) + Cell-frame origin in Cartesian coordinates. + cell : ASE Cell or 3x3 array + Lattice vectors as rows. + + Returns + ------- + grid_pos : np.ndarray of shape (Nx, Ny, Nz, 3) + Cartesian coordinates of every grid point. + """ + ngridpts = np.array(density.shape) + grid_pos = np.meshgrid( + np.arange(ngridpts[0]) / density.shape[0], + np.arange(ngridpts[1]) / density.shape[1], + np.arange(ngridpts[2]) / density.shape[2], + indexing="ij", + ) + grid_pos = np.stack(grid_pos, 3) + grid_pos = np.dot(grid_pos, np.asarray(cell)) + grid_pos = grid_pos + origin + return grid_pos + + +class LeMatRhoDeepDFTDataset(Dataset): + """Iterate LeMat-Rho parquet chunks as DeepDFT-shaped sample dicts. + + Parameters + ---------- + parquet_dir : str or Path + Directory containing ``chunk_*.parquet`` files. + _shared_index : tuple, optional + Internal: pre-built (file_paths, index) tuple shared between + train/val splits to avoid scanning files twice. 
+ """ + + def __init__( + self, + parquet_dir: str | Path | None = None, + _shared_index: Optional[tuple] = None, + ): + if _shared_index is not None: + self._file_paths, self._index = _shared_index + else: + if parquet_dir is None: + raise ValueError("Must provide parquet_dir or _shared_index") + self._file_paths, self._index = _build_parquet_index(Path(parquet_dir)) + + def __len__(self) -> int: + return len(self._index) + + def _read_row(self, idx: int) -> dict: + """Lazy per-worker chunk cache; unbounded, unlike charge3net_ft.data's LRU. + + Cache is keyed by the absolute parquet path (not the integer ``fi``) + so multiple ``LeMatRhoDeepDFTDataset`` instances pointing at different + directories don't collide on ``fi=0``. + """ + fi, ri = self._index[idx] + key = str(self._file_paths[fi].resolve()) + if key not in _DEEPDFT_TABLE_CACHE: + _DEEPDFT_TABLE_CACHE[key] = pq.read_table( + self._file_paths[fi], columns=_COLUMNS + ) + table = _DEEPDFT_TABLE_CACHE[key] + return {col: table.column(col)[ri].as_py() for col in _COLUMNS} + + def __getitem__(self, idx: int) -> dict: + row = self._read_row(idx) + atoms, density, origin = _row_to_atoms_and_density(row) + grid_pos = _calculate_grid_pos(density, origin, atoms.get_cell()) + + # Index-derived filename so DeepDFT logs stay distinguishable across + # samples. Format mirrors the tar member names DeepDFT normally sees. + fi, ri = self._index[idx] + chunk_stem = Path(self._file_paths[fi]).stem # e.g. "chunk_000017" + filename = f"{chunk_stem}_row{ri:06d}.parquet" + + return { + "density": density, + "atoms": atoms, + "origin": origin, + "grid_position": grid_pos, + "metadata": {"filename": filename}, + } diff --git a/deepdft_ft/runner.py b/deepdft_ft/runner.py new file mode 100644 index 0000000..ad49a35 --- /dev/null +++ b/deepdft_ft/runner.py @@ -0,0 +1,565 @@ +"""DeepDFT training runner — vendored from peterbjorgensen/DeepDFT@main. 
+ +Vendored rather than monkey-patched because the DDP integration touches +many points throughout `main()` (dataset construction, model wrap, +sampler, checkpoint save, logging gates). Keeping the patched copy here +makes the delta auditable and the code testable. + +Diff vs upstream: +- Adds DDP setup via `_setup_ddp`/`_is_main` helpers (mirrors the pattern + used in `charge3net_ft/train.py`). DDP activates iff `WORLD_SIZE>1`. +- Detects parquet directories and uses `LeMatRhoDeepDFTDataset` instead + of `dataset.DensityData`. Other arg formats are passed through to + upstream unchanged so the runner still works on the original tar/dir + datasets. +- `RandomSampler` swapped for `DistributedSampler` when DDP active. +- Model wrapped in `DistributedDataParallel`; checkpoint save/load unwraps + via `_unwrap`. +- Logging + checkpoint writes gated on rank 0. +""" + +from __future__ import annotations + +import os +import sys +import json +import argparse +import math +import logging +import itertools +import timeit +from pathlib import Path + +import numpy as np +import torch +import torch.utils.data +from torch.utils.data.distributed import DistributedSampler + +torch.set_num_threads(1) # Try to avoid thread overload on cluster + +# --------------------------------------------------------------------------- +# Make the DeepDFT sibling repo importable. 
Expected layout (mirrors +# how charge3net is set up): +# <repo>/ <-- LeMat-Rho +# <repo>/../DeepDFT/ <-- peterbjorgensen/DeepDFT clone +# --------------------------------------------------------------------------- +_DEEPDFT_ROOT = Path(__file__).resolve().parent.parent.parent / "DeepDFT" +if not _DEEPDFT_ROOT.exists(): + raise RuntimeError( + f"DeepDFT repo not found at {_DEEPDFT_ROOT}.\n" + "Clone it with: git clone https://github.com/peterbjorgensen/DeepDFT " + f"{_DEEPDFT_ROOT}" + ) +if str(_DEEPDFT_ROOT) not in sys.path: + sys.path.insert(0, str(_DEEPDFT_ROOT)) + +# --------------------------------------------------------------------------- +# Stub `asap3` if it isn't available. Building asap3 from source requires +# Python.h which isn't installed on Adastra (and getting it would need +# admin). Upstream DeepDFT supports an ASE-based fallback via +# `AseNeigborListWrapper`; we expose the same interface from `asap3.FullNeighborList` +# so the upstream `import asap3 ; asap3.FullNeighborList(...)` calls work. +# --------------------------------------------------------------------------- +try: + import asap3 # noqa: F401 +except ImportError: + import types + + import ase.neighborlist + import numpy as np + + _asap3_stub = types.ModuleType("asap3") + + class _AseFullNeighborList: + """Drop-in `asap3.FullNeighborList` replacement using ASE primitives. + + Behaviourally equivalent for DeepDFT's use case: ``get_neighbors(i, cutoff)`` + returns ``(indices, rel_positions, dist2)`` arrays. Much slower than real + asap3 but works without C++ headers. 
+ """ + + def __init__(self, cutoff, atoms): + self._cutoff = cutoff + self._positions = atoms.get_positions() + self._cell = np.asarray(atoms.get_cell()) + nl = ase.neighborlist.NewPrimitiveNeighborList( + cutoff, skin=0.0, self_interaction=False, bothways=True + ) + nl.build(atoms.get_pbc(), atoms.get_cell(), atoms.get_positions()) + self._nl = nl + + def get_neighbors(self, i, cutoff): + assert cutoff == self._cutoff, ( + "cutoff must match the one used at FullNeighborList init" + ) + indices, offsets = self._nl.get_neighbors(i) + rel_positions = ( + self._positions[indices] + offsets @ self._cell - self._positions[i] + ) + dist2 = (rel_positions**2).sum(axis=1) + return indices, rel_positions, dist2 + + _asap3_stub.FullNeighborList = _AseFullNeighborList + sys.modules["asap3"] = _asap3_stub + +import densitymodel # noqa: E402 (upstream module) +import dataset # noqa: E402 (upstream module) + +from deepdft_ft.data import LeMatRhoDeepDFTDataset # noqa: E402 + + +# --------------------------------------------------------------------------- +# Distributed-training helpers (same pattern as charge3net_ft/train.py). +# --------------------------------------------------------------------------- +def _is_ddp() -> bool: + return int(os.environ.get("WORLD_SIZE", "1")) > 1 + + +def _setup_ddp() -> tuple[int, int, int]: + """Returns (rank, local_rank, world_size). No-op when WORLD_SIZE=1.""" + if not _is_ddp(): + return 0, 0, 1 + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + # nccl routes through RCCL on AMD ROCm builds. 
+ torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(local_rank) + return rank, local_rank, world_size + + +def _is_main(rank: int) -> bool: + return rank == 0 + + +def _unwrap(model: torch.nn.Module) -> torch.nn.Module: + """Strip DistributedDataParallel for state_dict access.""" + return model.module if hasattr(model, "module") else model + + +def _is_parquet_dir(path: str | Path) -> bool: + """LeMat-Rho parquet dirs contain ``chunk_*.parquet``; tar/cube paths don't.""" + p = Path(path) + return p.is_dir() and any(p.glob("chunk_*.parquet")) + + +def get_arguments(arg_list=None): + parser = argparse.ArgumentParser( + description="Train graph convolution network", fromfile_prefix_chars="+" + ) + parser.add_argument( + "--load_model", + type=str, + default=None, + help="Load model parameters from previous run", + ) + parser.add_argument( + "--cutoff", + type=float, + default=5.0, + help="Atomic interaction cutoff distance [Å]", + ) + parser.add_argument( + "--split_file", + type=str, + default=None, + help="Train/test/validation split file json", + ) + parser.add_argument( + "--num_interactions", + type=int, + default=3, + help="Number of interaction layers used", + ) + parser.add_argument( + "--node_size", type=int, default=64, help="Size of hidden node states" + ) + parser.add_argument( + "--output_dir", + type=str, + default="runs/model_output", + help="Path to output directory", + ) + parser.add_argument( + "--dataset", + type=str, + default="data/qm9.db", + help="Path to ASE database", + ) + parser.add_argument( + "--max_steps", + type=int, + default=int(1e6), + help="Maximum number of optimisation steps", + ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Set which device to use for training e.g. 
'cuda' or 'cpu'", + ) + + parser.add_argument( + "--use_painn_model", + action="store_true", + help="Enable equivariant message passing model (PaiNN)", + ) + + parser.add_argument( + "--ignore_pbc", + action="store_true", + help="If flag is given, disable periodic boundary conditions (force to False) in atoms data", + ) + + parser.add_argument( + "--force_pbc", + action="store_true", + help="If flag is given, force periodic boundary conditions to True in atoms data", + ) + + return parser.parse_args(arg_list) + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self, name, fmt=":f"): + self.name = name + self.fmt = fmt + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" + return fmtstr.format(**self.__dict__) + + +def split_data(dataset, args): + # Load or generate splits + if args.split_file: + with open(args.split_file, "r") as fp: + splits = json.load(fp) + else: + datalen = len(dataset) + num_validation = int(math.ceil(datalen * 0.05)) + indices = np.random.permutation(len(dataset)) + splits = { + "train": indices[num_validation:].tolist(), + "validation": indices[:num_validation].tolist(), + } + + # Save split file + with open(os.path.join(args.output_dir, "datasplits.json"), "w") as f: + json.dump(splits, f) + + # Split the dataset + datasplits = {} + for key, indices in splits.items(): + datasplits[key] = torch.utils.data.Subset(dataset, indices) + return datasplits + + +def eval_model(model, dataloader, device): + with torch.no_grad(): + running_ae = torch.tensor(0.0, device=device) + running_se = torch.tensor(0.0, device=device) + running_count = torch.tensor(0.0, device=device) + for batch in dataloader: + device_batch = { + k: v.to(device=device, 
def get_normalization(dataset, per_atom=True):
    """Compute the mean and standard deviation of the dataset targets.

    Parameters
    ----------
    dataset :
        Iterable of samples; each sample is indexed with ``"targets"`` and,
        when ``per_atom`` is true, ``"num_nodes"``. If the dataset exposes
        ``dataset.transformer.targets`` its length sets the number of target
        channels, otherwise a single channel is assumed.
    per_atom :
        If true, every target is divided by the sample's ``"num_nodes"``
        before the statistics are accumulated.

    Returns
    -------
    (mean, std) : two tensors with one entry per target channel.
    """
    try:
        num_targets = len(dataset.transformer.targets)
    except AttributeError:
        num_targets = 1
    x_sum = torch.zeros(num_targets)
    x_2 = torch.zeros(num_targets)
    num_objects = 0
    for sample in dataset:
        x = sample["targets"]
        if per_atom:
            x = x / sample["num_nodes"]
        x_sum += x
        x_2 += x**2.0
        num_objects += 1
    # Var(X) = E[X^2] - E[X]^2
    x_mean = x_sum / num_objects
    x_var = x_2 / num_objects - x_mean**2.0
    # The one-pass formula subtracts two nearly equal numbers when the
    # targets are (almost) constant, so rounding can push the variance a
    # hair below zero and sqrt would return NaN. A variance is never
    # negative, so clamp before taking the root.
    x_var = torch.clamp(x_var, min=0.0)

    return x_mean, torch.sqrt(x_var)
+ if _is_ddp(): + args.device = f"cuda:{local_rank}" + + # Setup logging + os.makedirs(args.output_dir, exist_ok=True) + logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s [%(levelname)-5.5s] %(message)s", + handlers=[ + logging.FileHandler( + os.path.join(args.output_dir, "printlog.txt"), mode="w" + ), + logging.StreamHandler(), + ], + ) + + # Save command line args + with open(os.path.join(args.output_dir, "commandline_args.txt"), "w") as f: + f.write("\n".join(sys.argv[1:])) + # Save parsed command line arguments + with open(os.path.join(args.output_dir, "arguments.json"), "w") as f: + json.dump(vars(args), f) + + # Setup dataset and loader. If args.dataset points at a directory of + # LeMat-Rho chunk_*.parquet files, use our adapter; otherwise fall + # through to upstream's tar/cube/dir loader unchanged. + if _is_parquet_dir(args.dataset): + if is_main: + logging.info("loading LeMat-Rho parquet dir %s", args.dataset) + densitydata = LeMatRhoDeepDFTDataset(parquet_dir=args.dataset) + else: + if args.dataset.endswith(".txt"): + # Text file contains list of datafiles + with open(args.dataset, "r") as datasetfiles: + filelist = [ + os.path.join(os.path.dirname(args.dataset), line.strip("\n")) + for line in datasetfiles + ] + else: + filelist = [args.dataset] + if is_main: + logging.info("loading data %s", args.dataset) + densitydata = torch.utils.data.ConcatDataset( + [dataset.DensityData(path) for path in filelist] + ) + + # Split data into train and validation sets + datasplits = split_data(densitydata, args) + # Pool_size and num_workers downsized for LeMat-Rho cells whose r2SCAN + # CHGCARs are larger than the QM9/MP grids upstream was tuned for: the + # rotating pool keeps full grids in RAM per worker (pool_size * + # num_workers concurrent structures), and a handful of 200-300^3 cells + # is enough to OOM the 64 GB job at the upstream 20*4 = 80. 
+ datasplits["train"] = dataset.RotatingPoolData(datasplits["train"], 5) + + if args.ignore_pbc and args.force_pbc: + raise ValueError( + "ignore_pbc and force_pbc are mutually exclusive and can't both be set at the same time" + ) + elif args.ignore_pbc: + set_pbc = False + elif args.force_pbc: + set_pbc = True + else: + set_pbc = None + + # Setup loaders. With DDP, the train sampler shards data across ranks + # so each rank sees a disjoint subset per epoch. Val stays + # non-distributed and only rank 0 actually uses it. + if _is_ddp(): + train_sampler = DistributedSampler( + datasplits["train"], shuffle=True, drop_last=True + ) + else: + train_sampler = torch.utils.data.RandomSampler(datasplits["train"]) + train_loader = torch.utils.data.DataLoader( + datasplits["train"], + 2, + # See RotatingPoolData(...5) above; num_workers compounds the RAM + # footprint of the rotating pool. 2 workers x 5 pool = 10 grids in + # RAM peak, well below 64 GB for the LeMat-Rho size distribution. + num_workers=2, + sampler=train_sampler, + collate_fn=dataset.CollateFuncRandomSample( + args.cutoff, 1000, pin_memory=False, set_pbc_to=set_pbc + ), + ) + val_loader = torch.utils.data.DataLoader( + datasplits["validation"], + 2, + collate_fn=dataset.CollateFuncRandomSample( + args.cutoff, 5000, pin_memory=False, set_pbc_to=set_pbc + ), + num_workers=0, + ) + # Upstream materialised the full val_loader into a list at startup for + # speed ("Preloading validation batch"). Their NMC/QM9/ethyleneCarbonate + # val sets are ~100 materials so that's cheap. Ours is ~3.3 k materials + # x 5 000 probes/material -> ~150 GB if eagerly preloaded, which OOM-killed + # job 4971720. Leave val_loader as a streaming DataLoader instead; the + # data-loading overhead per val pass is negligible compared to DDP + # gradient sync (when DDP is enabled). Hyperparameters are unchanged. 
+ + # Initialise model + device = torch.device(args.device) + if args.use_painn_model: + net = densitymodel.PainnDensityModel( + args.num_interactions, args.node_size, args.cutoff + ) + else: + net = densitymodel.DensityModel( + args.num_interactions, args.node_size, args.cutoff + ) + if is_main: + logging.debug("model has %d parameters", count_parameters(net)) + net = net.to(device) + if _is_ddp(): + net = torch.nn.parallel.DistributedDataParallel( + net, device_ids=[local_rank], output_device=local_rank + ) + + # Setup optimizer + optimizer = torch.optim.Adam(net.parameters(), lr=0.0001) + criterion = torch.nn.MSELoss() + scheduler_fn = lambda step: 0.96 ** (step / 100000) # noqa: E731 (vendored) + scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_fn) + + log_interval = 5000 + running_loss = torch.tensor(0.0, device=device) + running_loss_count = torch.tensor(0, device=device) + best_val_mae = np.inf + step = 0 + # Restore checkpoint + if args.load_model: + state_dict = torch.load(args.load_model, map_location=device) + _unwrap(net).load_state_dict(state_dict["model"]) + step = state_dict["step"] + best_val_mae = state_dict["best_val_mae"] + optimizer.load_state_dict(state_dict["optimizer"]) + scheduler.load_state_dict(state_dict["scheduler"]) + + if is_main: + logging.info("start training") + + data_timer = AverageMeter("data_timer") + transfer_timer = AverageMeter("transfer_timer") + train_timer = AverageMeter("train_timer") + eval_timer = AverageMeter("eval_time") + + endtime = timeit.default_timer() + for _ in itertools.count(): + for batch_host in train_loader: + data_timer.update(timeit.default_timer() - endtime) + tstart = timeit.default_timer() + # Transfer to 'device' + batch = { + k: v.to(device=device, non_blocking=True) + for (k, v) in batch_host.items() + } + transfer_timer.update(timeit.default_timer() - tstart) + + tstart = timeit.default_timer() + # Reset gradient + optimizer.zero_grad() + + # Forward, backward and optimize + 
outputs = net(batch) + loss = criterion(outputs, batch["probe_target"]) + loss.backward() + optimizer.step() + + with torch.no_grad(): + running_loss += ( + loss + * batch["probe_target"].shape[0] + * batch["probe_target"].shape[1] + ) + running_loss_count += torch.sum(batch["num_probes"]) + + train_timer.update(timeit.default_timer() - tstart) + + # print(step, loss_value) + # Validate and save model + if (step % log_interval == 0) or ((step + 1) == args.max_steps): + tstart = timeit.default_timer() + with torch.no_grad(): + train_loss = (running_loss / running_loss_count).item() + running_loss = running_loss_count = 0 + + val_mae, val_rmse = eval_model(net, val_loader, device) + + if is_main: + logging.info( + "step=%d, val_mae=%g, val_rmse=%g, sqrt(train_loss)=%g", + step, + val_mae, + val_rmse, + math.sqrt(train_loss), + ) + + # Save checkpoint (rank 0 only). _unwrap so the state_dict + # is interchangeable between single-GPU and DDP runs. + if is_main and val_mae < best_val_mae: + best_val_mae = val_mae + torch.save( + { + "model": _unwrap(net).state_dict(), + "optimizer": optimizer.state_dict(), + "scheduler": scheduler.state_dict(), + "step": step, + "best_val_mae": best_val_mae, + }, + os.path.join(args.output_dir, "best_model.pth"), + ) + + eval_timer.update(timeit.default_timer() - tstart) + logging.debug( + "%s %s %s %s" + % (data_timer, transfer_timer, train_timer, eval_timer) + ) + step += 1 + + scheduler.step() + + if step >= args.max_steps: + if is_main: + logging.info("Max steps reached, exiting") + if _is_ddp(): + torch.distributed.destroy_process_group() + sys.exit(0) + + endtime = timeit.default_timer() + + +if __name__ == "__main__": + main() diff --git a/graph2mat_ft/__init__.py b/graph2mat_ft/__init__.py new file mode 100644 index 0000000..481fe10 --- /dev/null +++ b/graph2mat_ft/__init__.py @@ -0,0 +1,42 @@ +"""Graph2Mat-arm infrastructure for the r2SCAN benchmark (PARKED). + +PARKED 2026-05-25. 
Reasoning (see +``../plan_graph2mat_parked_2026-05-25.md``): + +Graph2Mat's native target is a per-pair atom-centered density +matrix ``D_ab``. VASP outputs only a grid density (not D_ab in any +localized basis), so training Graph2Mat on VASP r2SCAN would +require inventing a CHGCAR -> D_ab projection. Standard LSQR on +that is a 10^6 x 10^6 dense linear system per structure; the +matrix-free + neighbor-cutoff variant is multi-week research-grade +engineering with its own quality ceiling to validate. + +For the LeMat-Rho 3-arm comparison (ChargE3Net, DeepDFT, SALTED), +Graph2Mat is parked. The code below is correct as scaffolding and +ships with green tests; it can be revived if (1) we switch the +training set to a code that natively outputs D_ab (SIESTA, ...) or +(2) someone invests in the matrix-free projection. + +The basis adapter (PointBasis) and IO re-export are still useful +in their own right; left in place. +""" + +from graph2mat_ft.basis import basis_table_for_species, point_basis_for_species +from graph2mat_ft.io import read_chgcar, write_chgcar +from graph2mat_ft.model import Graph2MatModel +from graph2mat_ft.projection import ( + make_basis_configuration, + pack_coeffs_to_point_labels, + unpack_point_labels_to_coeffs, +) + +__all__ = [ + "Graph2MatModel", + "basis_table_for_species", + "make_basis_configuration", + "pack_coeffs_to_point_labels", + "point_basis_for_species", + "read_chgcar", + "unpack_point_labels_to_coeffs", + "write_chgcar", +] diff --git a/graph2mat_ft/basis.py b/graph2mat_ft/basis.py new file mode 100644 index 0000000..a9e7907 --- /dev/null +++ b/graph2mat_ft/basis.py @@ -0,0 +1,63 @@ +"""Adapter from our uniform ``BasisSpec`` to Graph2Mat's ``PointBasis``. + +Graph2Mat ships ``PointBasis`` as the per-species basis description. +For each species, ``PointBasis(type, R, basis, basis_convention)`` +carries the cutoff, the per-l radial count, and the spherical- +harmonic convention. 
def point_basis_for_species(symbol: str, basis_spec: BasisSpec) -> PointBasis:
    """Expand the species-uniform ``BasisSpec`` into one Graph2Mat ``PointBasis``.

    Parameters
    ----------
    symbol :
        Atomic symbol (``"H"``, ``"Fe"``, etc.); passed as ``PointBasis.type``.
    basis_spec :
        The BasisSpec shared with salted_ft. ``cutoff`` is passed as ``R``,
        ``n_radial`` is repeated once per angular momentum, and
        ``max_l + 1`` is the length of the resulting ``basis`` list.

    Returns
    -------
    PointBasis built with ``basis_convention="spherical"``. The module
    docstring states that Graph2Mat's ``basis_size`` for this input equals
    ``basis_spec.n_coeffs_per_atom``.
    """
    # One entry per angular momentum l = 0..max_l; the list index is l and
    # the value is the number of radial functions at that l.
    n_angular_channels = basis_spec.max_l + 1
    radial_counts = [basis_spec.n_radial for _ in range(n_angular_channels)]
    cutoff_radius = float(basis_spec.cutoff)
    return PointBasis(
        type=symbol,
        R=cutoff_radius,
        basis=radial_counts,
        basis_convention="spherical",
    )
class Graph2MatModel:
    """Callable that maps a structure to per-atom basis coefficients.

    Parameters
    ----------
    basis_spec :
        Basis the coefficients are expressed in. Must be the spec the
        trained checkpoint was trained on.
    ckpt_path :
        Location of a Graph2Mat checkpoint. With ``None`` (the default) the
        model is a stub that emits deterministic, position-dependent fake
        coefficients for exercising the surrounding pipeline.
    """

    def __init__(
        self, basis_spec: BasisSpec, ckpt_path: str | Path | None = None
    ) -> None:
        self.basis_spec = basis_spec
        self.ckpt_path = None if ckpt_path is None else Path(ckpt_path)
        # Placeholder for the real backbone; nothing in this class sets it yet.
        self._g2m_model = None

    def __call__(self, atoms: ase.Atoms) -> np.ndarray:
        """Return coefficients for ``atoms``.

        Returns
        -------
        np.ndarray of shape ``(n_atoms, basis_spec.n_coeffs_per_atom)``,
        float64, deterministic, finite.
        """
        if self.ckpt_path is not None:
            return self._g2m_predict(atoms)
        return self._stub_predict(atoms)

    def reconstruct_density(
        self, atoms: ase.Atoms, grid_shape: tuple[int, int, int]
    ) -> np.ndarray:
        """Predict coefficients and expand them onto a real-space grid.

        Shorthand for::

            c = model(atoms)
            reconstruct_grid_from_basis(c, atoms, grid_shape, basis_spec)
        """
        return reconstruct_grid_from_basis(
            self(atoms), atoms, grid_shape, self.basis_spec
        )

    def _stub_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Fake coefficients that depend only on the structure and the basis.

        The RNG seed is a 128-bit BLAKE2b digest of the positions, the atomic
        numbers and ``str(basis_spec)``, so identical input always produces
        identical output. Values are scaled by 1e-3 to keep reconstructed
        densities small.
        """
        out_shape = (len(atoms), self.basis_spec.n_coeffs_per_atom)

        # The full byte content of every array goes into the hash; seeding
        # from a truncated slice of the raw bytes would ignore most atoms
        # and let different structures share a seed.
        payload = b"".join(
            (
                atoms.get_positions().astype(np.float64).tobytes(),
                atoms.get_atomic_numbers().astype(np.int64).tobytes(),
                str(self.basis_spec).encode("utf-8"),
            )
        )
        fingerprint = hashlib.blake2b(payload, digest_size=16).digest()
        seed = int.from_bytes(fingerprint, "little")
        noise = np.random.default_rng(seed).standard_normal(
            out_shape, dtype=np.float64
        )
        return noise * 1e-3

    def _g2m_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Real Graph2Mat forward pass; not implemented in this module yet."""
        raise NotImplementedError(
            "Real Graph2Mat forward pass is deferred to D6. "
            "Construct Graph2MatModel with ckpt_path=None for stub mode."
        )
def pack_coeffs_to_point_labels(
    coeffs: np.ndarray,
    basis_spec: BasisSpec,
    symbols: Sequence[str],
) -> np.ndarray:
    """Flatten per-atom coefficients into Graph2Mat per-node labels.

    Parameters
    ----------
    coeffs :
        2D array of shape ``(N_atoms, n_coeffs_per_atom)``.
    basis_spec :
        Supplies ``n_coeffs_per_atom``; used only to validate the shape.
    symbols :
        Per-atom species symbols. Length must match ``N_atoms``.

    Returns
    -------
    1D array of length ``N_atoms * n_coeffs_per_atom``, atom-major
    (atom 0's block first, then atom 1, ...). It has the dtype of
    ``coeffs`` and never shares memory with it.

    Raises
    ------
    ValueError
        If ``coeffs`` is not 2D, has the wrong number of channels per atom,
        or its atom count differs from ``len(symbols)``.
    """
    # Check the rank first: indexing shape[1] on a 0-D/1-D array would
    # surface as a bare IndexError rather than a descriptive ValueError.
    if coeffs.ndim != 2:
        raise ValueError(
            "coeffs must be 2D (N_atoms, n_coeffs_per_atom); "
            f"got shape {coeffs.shape}"
        )
    if coeffs.shape[1] != basis_spec.n_coeffs_per_atom:
        raise ValueError(
            f"coeffs has {coeffs.shape[1]} channels per atom but BasisSpec "
            f"declares {basis_spec.n_coeffs_per_atom}"
        )
    if coeffs.shape[0] != len(symbols):
        raise ValueError(
            f"coeffs has {coeffs.shape[0]} atoms but got {len(symbols)} symbols"
        )
    # flatten() keeps the dtype, walks the array in C (row-major) order
    # whatever the memory layout, and always returns a fresh copy.
    return coeffs.flatten()
+ + The basis list is built once per call from the unique species in + ``symbols`` so the resulting config carries only the species it + actually contains (a downstream BasisTableWithEdges may union + these across the dataset). + + Parameters + ---------- + positions : + ``(N_atoms, 3)`` Cartesian atomic positions in Angstroms. + cell : + ``(3, 3)`` lattice matrix. + symbols : + Per-atom species symbols. + basis_spec : + Defines the per-species ``PointBasis`` (uniform across + species in v1). + """ + # Lazy-import keeps the module importable without graph2mat installed + # (the test class importorskips, so this only runs when present). + from graph2mat import BasisConfiguration + + from graph2mat_ft.basis import basis_table_for_species + + table = basis_table_for_species(symbols, basis_spec) + basis_list = list(table.values()) + symbol_to_idx = {pb.type: i for i, pb in enumerate(basis_list)} + point_types = np.array([symbol_to_idx[s] for s in symbols], dtype=np.int64) + + return BasisConfiguration( + point_types=point_types, + positions=np.asarray(positions, dtype=np.float64), + basis=basis_list, + cell=np.asarray(cell, dtype=np.float64), + pbc=(True, True, True), + ) diff --git a/pyproject.toml b/pyproject.toml index 08f1061..fa608d2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,9 @@ dependencies = [ "pyarrow>=14.0.0", "wandb>=0.16.0", "python-dotenv>=1.0.0", + "metatensor>=0.2.0", + "chemfiles>=0.10.4", + "graph2mat>=0.0.13", ] [tool.uv.sources] diff --git a/salted_ft/__init__.py b/salted_ft/__init__.py new file mode 100644 index 0000000..7655b6e --- /dev/null +++ b/salted_ft/__init__.py @@ -0,0 +1,17 @@ +"""SALTED-arm basis-expansion infrastructure for the r2SCAN benchmark. + +This package wraps rholearn (`lab-cosmo/rholearn`) and provides the +projection/reconstruction bridge between LeMat-Rho VASP CHGCAR data +and the rholearn training/inference pipeline. 
@dataclass(frozen=True)
class BasisSpec:
    """Immutable description of the atom-centered Gaussian x Y_lm basis.

    Parameters
    ----------
    max_l :
        Largest angular momentum, inclusive. Real spherical harmonics
        Y_lm with ``l = 0..max_l`` and ``m = -l..l`` are used.
    n_radial :
        Number of radial channels. Must equal ``len(sigma)``.
    sigma :
        Gaussian widths (Angstrom), one per radial channel.
    cutoff :
        Radial cutoff (Angstrom) beyond which basis functions are zero.
        Should match the cutoff used by the neighbor-list / graph
        constructor of the downstream ML model.

    Raises
    ------
    ValueError
        At construction time, for the first violated constraint.
    """

    max_l: int = 4
    n_radial: int = 4
    sigma: tuple[float, ...] = (0.5, 1.0, 2.0, 4.0)
    cutoff: float = 4.0

    def __post_init__(self) -> None:
        # Checks run in a fixed order and stop at the first failure, so a
        # malformed spec is rejected on construction rather than later
        # inside numerical code.
        problem = None
        if self.max_l < 0:
            problem = (
                f"max_l must be >= 0; got {self.max_l}. "
                "Use max_l=0 for an s-only basis."
            )
        elif self.n_radial < 1:
            problem = (
                f"n_radial must be >= 1; got {self.n_radial}. "
                "A basis with zero radial channels has no expressive power."
            )
        elif len(self.sigma) != self.n_radial:
            problem = (
                f"n_radial ({self.n_radial}) must equal len(sigma) "
                f"({len(self.sigma)}); each radial channel needs its own width."
            )
        elif any(width <= 0 for width in self.sigma):
            problem = (
                f"sigma values must be positive (Gaussian widths); got {self.sigma}."
            )
        elif self.cutoff <= 0:
            problem = (
                f"cutoff must be > 0; got {self.cutoff}. "
                "A nonpositive cutoff makes the basis identically zero."
            )
        if problem is not None:
            raise ValueError(problem)

    @property
    def n_angular_components(self) -> int:
        """Count of real-Ylm components over l = 0..max_l, i.e. (max_l + 1)^2."""
        n_l_values = self.max_l + 1
        return n_l_values**2

    @property
    def n_coeffs_per_atom(self) -> int:
        """Coefficients per atom: radial channels times angular components."""
        return self.n_radial * self.n_angular_components

    def total_coeffs_shape(self, n_atoms: int) -> tuple[int, int]:
        """Return ``(n_atoms, n_coeffs_per_atom)`` for a structure's coefficients."""
        return n_atoms, self.n_coeffs_per_atom
def write_chgcar(
    density: np.ndarray,
    atoms: ase.Atoms,
    path: str | Path,
    n_electrons: float | None = None,
) -> None:
    """Write a real-space density grid to a VASP CHGCAR file.

    Parameters
    ----------
    density :
        Real-space density on a regular grid, shape ``(Nx, Ny, Nz)``.
        Any array-like that converts to a 3D float64 array is accepted;
        the caller's data is never modified.
    atoms :
        Periodic structure; provides cell + species ordering.
    path :
        Output file path.
    n_electrons :
        If given (and > 0), rescale the density so that
        ``sum(rho) * cell_volume / N_grid`` equals this value before
        writing. ``None`` writes the density unscaled.

    Raises
    ------
    ValueError
        If ``density`` is not 3D, or ``n_electrons`` is given but not > 0.
    """
    # Convert before validating so nested lists and other array-likes get
    # the same shape check (and the same error) as ndarrays. The copy keeps
    # the in-place rescale below away from the caller's array.
    rho = np.asarray(density, dtype=np.float64).copy()
    if rho.ndim != 3:
        raise ValueError(
            f"density must be a 3D grid (Nx, Ny, Nz); got shape {rho.shape}"
        )
    if n_electrons is not None and n_electrons <= 0:
        raise ValueError(
            f"n_electrons must be > 0; got {n_electrons}. Use None to skip rescaling."
        )

    # Imported lazily so argument validation does not need pymatgen.
    from pymatgen.io.ase import AseAtomsAdaptor
    from pymatgen.io.vasp.outputs import Chgcar

    structure = AseAtomsAdaptor.get_structure(atoms)
    cell_volume = float(structure.lattice.volume)

    if n_electrons is not None:
        current_total = rho.sum() * cell_volume / rho.size
        # An identically-zero integral cannot be rescaled; the grid is
        # written as-is in that case instead of dividing by zero.
        if current_total != 0.0:
            rho *= n_electrons / current_total

    # The CHGCAR data handed to pymatgen's Chgcar is density * cell_volume;
    # read_chgcar divides by the volume again, so a write/read round trip
    # returns the user-facing rho.
    chgcar = Chgcar(structure, {"total": rho * cell_volume})
    chgcar.write_file(str(path))
+ +The wrapper exposes a single-call interface +``coefficients = model(atoms)`` so the SALTED arm slots into the same +evaluation pipeline as ChargE3Net / DeepDFT: predict, reconstruct on +the VASP FFT grid, compare against the converged density via NMAPE +and friends. + +When constructed with ``ckpt_path=None`` the model is in **stub mode**: +it returns deterministic, position-dependent coefficients without +requiring a trained rholearn checkpoint. This is what powers the +unit tests and the end-to-end pipeline plumbing tests during PR +gamma; PR gamma-prime (a follow-up) will swap in real rholearn +forward calls. + +When ``ckpt_path`` points at a real rholearn checkpoint the model +delegates to rholearn. The rholearn sibling repo is expected at +``../rholearn/`` relative to the LeMat-Rho clone (same pattern as +``charge3net`` for ChargE3Net and ``DeepDFT`` for DeepDFT). +""" + +from __future__ import annotations + +import hashlib +import sys +from pathlib import Path + +import ase +import numpy as np + +from salted_ft.basis import BasisSpec +from salted_ft.projection import reconstruct_grid_from_basis + +# rholearn sibling-repo discovery follows the same pattern as +# charge3net_ft/model.py and deepdft_ft/runner.py. Resolution is lazy: +# we only insist on the sibling repo when ckpt_path is provided. +_RHOLEARN_ROOT = Path(__file__).resolve().parent.parent.parent / "rholearn" + + +def _ensure_rholearn_importable() -> None: + """Make ``rholearn`` importable; only called when ckpt_path is set.""" + if not _RHOLEARN_ROOT.exists(): + raise RuntimeError( + f"rholearn repo not found at {_RHOLEARN_ROOT}.\n" + "Clone it with: git clone https://github.com/lab-cosmo/rholearn " + f"{_RHOLEARN_ROOT}\n" + "Note: the metatensor.torch.atomistic -> metatomic.torch namespace " + "patch in rholearn/utils/system.py may also be required." 
class SALTEDModel:
    """Predict atom-centered basis coefficients for a structure.

    Parameters
    ----------
    basis_spec :
        The basis the coefficients are defined against. Must match the
        spec the checkpoint was trained on; a mismatch is reported with
        a warning when the checkpoint is loaded.
    ckpt_path :
        Path to a baseline checkpoint of the form
        ``{"basis_spec": BasisSpec, "model": state_dict}``. If ``None``
        (default), the model runs in stub mode: deterministic,
        position-dependent fake coefficients useful for testing the
        surrounding pipeline. rholearn checkpoints are not supported by
        this class.
    """

    def __init__(
        self, basis_spec: BasisSpec, ckpt_path: str | Path | None = None
    ) -> None:
        self.basis_spec = basis_spec
        self.ckpt_path = Path(ckpt_path) if ckpt_path is not None else None
        # Filled lazily by the first checkpoint-backed prediction. The
        # attribute keeps its historical name although it now holds the
        # baseline network rather than a rholearn model.
        self._rholearn_model = None

    def __call__(self, atoms: ase.Atoms) -> np.ndarray:
        """Predict coefficients for ``atoms``.

        Dispatches to the stub when no checkpoint path was given and to
        the checkpoint-backed baseline otherwise.

        Returns
        -------
        np.ndarray of shape ``(n_atoms, n_coeffs_per_atom)``, float64.
        """
        if self.ckpt_path is None:
            return self._stub_predict(atoms)
        return self._baseline_predict(atoms)

    def reconstruct_density(
        self, atoms: ase.Atoms, grid_shape: tuple[int, int, int]
    ) -> np.ndarray:
        """Predict coefficients, then reconstruct the real-space density.

        Equivalent to::

            c = model(atoms)
            reconstruct_grid_from_basis(c, atoms, grid_shape, basis_spec)

        Provided as a convenience for the VASP comparison pipeline,
        which always wants the grid form.
        """
        coeffs = self(atoms)
        return reconstruct_grid_from_basis(coeffs, atoms, grid_shape, self.basis_spec)

    # ------------------------------------------------------------------
    # Implementations
    # ------------------------------------------------------------------
    def _stub_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Return deterministic, position-dependent coefficients.

        A NumPy generator is seeded from a hash of the atomic positions,
        the atomic numbers and ``str(basis_spec)``, so identical input
        gives identical output and different positions give different
        output. Values are standard-normal draws scaled by ``1e-3``.
        """
        n_atoms = len(atoms)
        n_coeffs = self.basis_spec.n_coeffs_per_atom
        positions = atoms.get_positions()
        numbers = atoms.get_atomic_numbers()

        # Feed the complete byte streams into the hash; truncating the
        # raw bytes instead would make the seed depend on the first atom
        # only.
        digest = hashlib.blake2b(
            positions.astype(np.float64).tobytes()
            + numbers.astype(np.int64).tobytes()
            + str(self.basis_spec).encode("utf-8"),
            digest_size=16,
        ).digest()
        seed_int = int.from_bytes(digest, byteorder="little", signed=False)
        rng = np.random.default_rng(seed_int)
        return rng.standard_normal((n_atoms, n_coeffs), dtype=np.float64) * 1e-3

    def _baseline_predict(self, atoms: ase.Atoms) -> np.ndarray:
        """Load the baseline checkpoint (once) and predict coefficients.

        Checkpoint format (see ``salted_ft.train_baseline.train``)::

            {"basis_spec": BasisSpec, "model": state_dict}

        The network is cached on the instance after the first call.

        Raises
        ------
        RuntimeError
            If the loaded object is not a dict containing ``"model"``.
        """
        # Lazy imports: torch is heavy and stub mode does not require it.
        import warnings

        import torch

        from salted_ft.train_baseline import SaltedBaselineModel

        if self._rholearn_model is None:
            # NOTE(security): weights_only=False unpickles arbitrary
            # Python objects, so only trusted checkpoints may be loaded.
            # It is used because the checkpoint format stores a BasisSpec
            # object alongside the state dict.
            state = torch.load(
                str(self.ckpt_path), map_location="cpu", weights_only=False
            )
            if not isinstance(state, dict) or "model" not in state:
                raise RuntimeError(
                    f"Checkpoint at {self.ckpt_path} is not in the expected "
                    "baseline format ({'basis_spec': ..., 'model': state_dict}). "
                    "If this is a rholearn checkpoint, that path is deferred."
                )
            ckpt_spec = state.get("basis_spec", self.basis_spec)
            if ckpt_spec != self.basis_spec:
                # The network is built with the checkpoint's basis while
                # reconstruct_density uses self.basis_spec; surface the
                # inconsistency instead of letting it pass silently.
                warnings.warn(
                    f"Checkpoint basis spec {ckpt_spec!r} differs from the "
                    f"model's basis spec {self.basis_spec!r}; predicted "
                    "coefficients follow the checkpoint's basis.",
                    stacklevel=2,
                )
            model = SaltedBaselineModel(ckpt_spec)
            model.load_state_dict(state["model"])
            model.train(False)
            self._rholearn_model = model

        with torch.no_grad():
            pred = self._rholearn_model(atoms)
        return pred.detach().cpu().numpy().astype(np.float64)
def project_chunk(
    in_path: str | Path,
    out_path: str | Path,
    basis_spec: BasisSpec,
) -> None:
    """Project every valid row in ``in_path`` and write ``out_path``.

    Rows whose ``compressed_charge_density`` is null are skipped. Each
    remaining row is projected onto ``basis_spec``, reconstructed again
    to measure the basis-set error, and stored together with the
    structure metadata.

    The result is written to a temporary sibling file and moved into
    place afterwards, so an interrupted run never leaves a truncated,
    non-empty ``out_path`` behind for a resume check to mistake for a
    finished chunk.

    Parameters
    ----------
    in_path :
        Source parquet chunk.
    out_path :
        Destination parquet file; parent directories are created.
    basis_spec :
        Basis to project onto.
    """
    in_path = Path(in_path)
    out_path = Path(out_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    # material_id is optional; include it if present so downstream can match
    # to the source LeMat-Rho row. Never request the same column twice.
    has_material_id = "material_id" in pq.read_schema(in_path).names
    columns = list(_COLUMNS)
    if has_material_id and "material_id" not in columns:
        columns.append("material_id")

    table = pq.read_table(in_path, columns=columns)

    # Resolve the column objects once instead of once per row.
    density_col = table.column("compressed_charge_density")
    row_cols = {col: table.column(col) for col in _COLUMNS}
    material_col = table.column("material_id") if has_material_id else None

    out_rows: list[dict] = []
    for ri in range(len(table)):
        if not density_col[ri].is_valid:
            continue  # skip null density (failed DFT extraction in source)

        row = {col: values[ri].as_py() for col, values in row_cols.items()}
        atoms, density, _origin = _row_to_atoms_and_density(row)

        coeffs = project_chgcar_to_basis(density, atoms, basis_spec)
        reconstructed = reconstruct_grid_from_basis(
            coeffs, atoms, density.shape, basis_spec
        )
        nmape = _row_nmape(density, reconstructed)

        cell = np.asarray(atoms.get_cell(), dtype=np.float64)
        cell_volume = float(np.abs(np.linalg.det(cell)))
        n_grid = int(np.prod(density.shape))
        n_electrons = float(density.sum() * cell_volume / n_grid)

        out_rows.append(
            {
                "row_index": ri,
                "material_id": (
                    material_col[ri].as_py() if material_col is not None else ""
                ),
                "n_atoms": int(len(atoms)),
                "atomic_numbers": atoms.get_atomic_numbers().tolist(),
                "lattice_vectors": cell.tolist(),
                "n_electrons": n_electrons,
                "grid_shape": list(density.shape),
                "coefficients": coeffs.tolist(),
                "basis_set_NMAPE": nmape,
            }
        )

    out_table = pa.Table.from_pylist(out_rows)

    # Write-then-rename: out_path only ever appears fully written.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        pq.write_table(out_table, tmp_path)
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        tmp_path.unlink(missing_ok=True)
+ """ + input_dir = Path(input_dir) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + inputs = sorted(input_dir.glob("chunk_*.parquet")) + if not inputs: + raise FileNotFoundError(f"no chunk_*.parquet files under {input_dir}") + + for in_path in inputs: + out_path = output_dir / in_path.name + if out_path.exists() and out_path.stat().st_size > 0: + continue + project_chunk(in_path, out_path, basis_spec) + + +def _main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Project the LeMat-Rho parquet dataset onto the SALTED basis." + ) + parser.add_argument("--input-dir", required=True, type=Path) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument( + "--basis-spec", + type=str, + default=None, + help="JSON-encoded BasisSpec overrides. If omitted, defaults are used.", + ) + args = parser.parse_args(argv) + + if args.basis_spec: + overrides = json.loads(args.basis_spec) + # sigma must be tuple-ified to satisfy BasisSpec's frozen dataclass + if "sigma" in overrides: + overrides["sigma"] = tuple(overrides["sigma"]) + spec = BasisSpec(**overrides) + else: + spec = BasisSpec() + print( + f"BasisSpec: lmax={spec.max_l}, n_radial={spec.n_radial}, " + f"sigma={spec.sigma}, cutoff={spec.cutoff}, " + f"n_coeffs_per_atom={spec.n_coeffs_per_atom}" + ) + + project_directory(args.input_dir, args.output_dir, spec) + return 0 + + +if __name__ == "__main__": + raise SystemExit(_main()) diff --git a/salted_ft/projection.py b/salted_ft/projection.py new file mode 100644 index 0000000..4ebc895 --- /dev/null +++ b/salted_ft/projection.py @@ -0,0 +1,323 @@ +"""VASP density grid <-> atom-centered Gaussian * Y_lm basis coefficients. + +The two operations defined here are the DIY bridge between VASP plane-wave +CHGCAR data and the rholearn/SALTED localized-basis world. 
See the memo +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` (Phase A) for why we +have to build this layer ourselves. + +Math +---- + +The basis expansion is +:: + + rho(r) ~= sum_i sum_n sum_{l,m} c_{i,n,l,m} phi_n(|r - r_i|) Y_lm(rhat) + +where ``i`` indexes atoms, ``n`` is the radial channel, ``(l, m)`` are the +real spherical harmonic indices, ``phi_n`` is a Gaussian of width +``sigma_n``, and ``Y_lm`` is a real spherical harmonic. + +We use the **orthonormal-approximation projection**: each coefficient is +the inner product of the density with the corresponding basis function, +normalized by the basis function's L2 norm. This is exact iff the basis +is orthonormal; for our Gaussians it's a v1 stand-in for a proper +overlap-matrix least-squares solve, which lands in a follow-up PR. + +Reconstruction is the literal sum on the right-hand side. + +Both maps are linear in their input (linearity is a pinned test). + +PBC +--- + +Minimum-image convention via the cell inverse. Each grid point sees each +atom at its closest periodic image. Adequate for cells where 2*cutoff +fits inside the smallest cell vector; for very small cells we'd want +full Ewald-style supercell expansion. Not in scope for PR beta. +""" + +from __future__ import annotations + +import ase +import numpy as np + +from salted_ft.basis import BasisSpec + + +# --------------------------------------------------------------------------- +# Grid-position generation (matches charge3net's `calculate_grid_pos` plus +# `deepdft_ft.data._calculate_grid_pos` so the three pipelines agree on +# where grid point (i, j, k) lives in space). +# --------------------------------------------------------------------------- +def _grid_positions(grid_shape: tuple[int, int, int], cell: np.ndarray) -> np.ndarray: + """Cartesian coordinates of every grid point. 
+ + Parameters + ---------- + grid_shape : (Nx, Ny, Nz) + cell : (3, 3) lattice matrix with rows as vectors + + Returns + ------- + (Nx * Ny * Nz, 3) Cartesian coordinates, ``[i, j, k]`` order matching + ``np.ravel`` of an array of that shape. + """ + # Silence harmless RuntimeWarnings from intermediate matmul reductions. + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + Nx, Ny, Nz = grid_shape + fx = np.arange(Nx, dtype=np.float64) / Nx + fy = np.arange(Ny, dtype=np.float64) / Ny + fz = np.arange(Nz, dtype=np.float64) / Nz + fX, fY, fZ = np.meshgrid(fx, fy, fz, indexing="ij") + frac = np.stack([fX.ravel(), fY.ravel(), fZ.ravel()], axis=-1) + return frac @ cell # (n_grid, 3) + + +# --------------------------------------------------------------------------- +# Real spherical harmonics. We hand-roll real Y_lm for lmax up to 4 +# (covers our default lmax=4) because the alternatives are either heavy +# (e3nn/torch in a pure-numpy module) or complex-valued (scipy.special). +# --------------------------------------------------------------------------- +_SQRT_PI = np.sqrt(np.pi) + + +def _real_sph_harm(rhat: np.ndarray, lmax: int) -> np.ndarray: + """Real spherical harmonics on unit vectors, l = 0..lmax inclusive. + + Returns an array of shape ``(..., (lmax + 1) ** 2)`` where the last + axis is ordered ``[Y_00, Y_1{-1}, Y_10, Y_11, Y_2{-2}, ..., Y_l l]`` + (the standard SOAP / SALTED ordering). + + Parameters + ---------- + rhat : (..., 3) array + Unit vectors. Zero-length inputs are handled by the caller. + lmax : + Maximum angular momentum, inclusive. + """ + if lmax > 4: + raise NotImplementedError( + f"_real_sph_harm only implements l = 0..4 (lmax={lmax} requested). " + "Extend or swap in e3nn.o3.spherical_harmonics for higher lmax." 
+ ) + x, y, z = rhat[..., 0], rhat[..., 1], rhat[..., 2] + n_lm = (lmax + 1) ** 2 + out = np.empty(rhat.shape[:-1] + (n_lm,), dtype=np.float64) + + # l = 0 + out[..., 0] = 0.5 / _SQRT_PI + + if lmax >= 1: + # l = 1: Y_1{-1} ~ y, Y_10 ~ z, Y_11 ~ x + c1 = 0.5 * np.sqrt(3.0 / np.pi) + out[..., 1] = c1 * y + out[..., 2] = c1 * z + out[..., 3] = c1 * x + + if lmax >= 2: + # l = 2 + c2_xy = 0.5 * np.sqrt(15.0 / np.pi) # Y_2{-2}, Y_21, Y_2{-1} prefactors + c2_z2 = 0.25 * np.sqrt(5.0 / np.pi) + c2_x2y2 = 0.25 * np.sqrt(15.0 / np.pi) + out[..., 4] = c2_xy * x * y # Y_2{-2} + out[..., 5] = c2_xy * y * z # Y_2{-1} + out[..., 6] = c2_z2 * (3 * z * z - 1) # Y_20 + out[..., 7] = c2_xy * x * z # Y_21 + out[..., 8] = c2_x2y2 * (x * x - y * y) # Y_22 + + if lmax >= 3: + # l = 3 + c3a = 0.25 * np.sqrt(35.0 / (2.0 * np.pi)) + c3b = 0.5 * np.sqrt(105.0 / np.pi) + c3c = 0.25 * np.sqrt(21.0 / (2.0 * np.pi)) + c3d = 0.25 * np.sqrt(7.0 / np.pi) + out[..., 9] = c3a * y * (3 * x * x - y * y) # Y_3{-3} + out[..., 10] = c3b * x * y * z # Y_3{-2} + out[..., 11] = c3c * y * (5 * z * z - 1) # Y_3{-1} + out[..., 12] = c3d * z * (5 * z * z - 3) # Y_30 + out[..., 13] = c3c * x * (5 * z * z - 1) # Y_31 + out[..., 14] = 0.25 * np.sqrt(105.0 / np.pi) * z * (x * x - y * y) # Y_32 + out[..., 15] = c3a * x * (x * x - 3 * y * y) # Y_33 + + if lmax >= 4: + # l = 4 + c4a = 0.75 * np.sqrt(35.0 / np.pi) + c4b = 0.75 * np.sqrt(35.0 / (2.0 * np.pi)) + c4c = 0.75 * np.sqrt(5.0 / np.pi) + c4d = 0.75 * np.sqrt(5.0 / (2.0 * np.pi)) + c4e = 3.0 / 16.0 * np.sqrt(1.0 / np.pi) + out[..., 16] = c4a * x * y * (x * x - y * y) # Y_4{-4} + out[..., 17] = c4b * y * z * (3 * x * x - y * y) # Y_4{-3} + out[..., 18] = c4c * x * y * (7 * z * z - 1) # Y_4{-2} + out[..., 19] = c4d * y * z * (7 * z * z - 3) # Y_4{-1} + out[..., 20] = c4e * (35 * z**4 - 30 * z * z + 3) # Y_40 + out[..., 21] = c4d * x * z * (7 * z * z - 3) # Y_41 + out[..., 22] = ( + 0.375 * np.sqrt(5.0 / np.pi) * (x * x - y * y) * (7 * z * z - 1) + ) # Y_42 + 
out[..., 23] = c4b * x * z * (x * x - 3 * y * y) # Y_43 + out[..., 24] = ( + 0.1875 + * np.sqrt(35.0 / np.pi) + * (x * x * (x * x - 3 * y * y) - y * y * (3 * x * x - y * y)) + ) # Y_44 + + return out + + +# --------------------------------------------------------------------------- +# Per-atom basis-function evaluation at grid points +# --------------------------------------------------------------------------- +def _eval_basis_at_grid( + atom_position: np.ndarray, + grid_positions: np.ndarray, + cell: np.ndarray, + basis_spec: BasisSpec, +) -> np.ndarray: + """Evaluate every basis function centered on ``atom_position`` at every + grid point, using minimum-image convention. + + Returns ``(n_grid, n_coeffs_per_atom)`` array of basis-function values. + """ + # The masked points outside the cutoff intentionally produce some + # 0/0 and large-magnitude intermediates whose results we throw away + # via ``mask``. Silence the harmless RuntimeWarnings to keep test + # output readable. + with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + inv_cell = np.linalg.inv(cell) + rel = grid_positions - atom_position[None, :] # (n_grid, 3) + # Minimum-image: wrap fractional displacement to [-0.5, 0.5] + frac_disp = rel @ inv_cell + frac_disp = frac_disp - np.round(frac_disp) + rel = frac_disp @ cell # (n_grid, 3) in Cartesian, wrapped + + r = np.linalg.norm(rel, axis=-1) # (n_grid,) + mask = r < basis_spec.cutoff + r_safe = np.where(r > 0, r, 1.0) + rhat = rel / r_safe[:, None] + + # Real spherical harmonics, (n_grid, (lmax+1)^2) + ylm = _real_sph_harm(rhat, basis_spec.max_l) + + n_grid = grid_positions.shape[0] + n_lm = ylm.shape[-1] + n_radial = basis_spec.n_radial + out = np.empty((n_grid, n_radial * n_lm), dtype=np.float64) + + for n_idx, sigma in enumerate(basis_spec.sigma): + radial = np.exp(-0.5 * (r / sigma) ** 2) * mask # (n_grid,) + # block layout: [n=0 lm=0..nlm-1, n=1 lm=0..nlm-1, ...] 
def project_chgcar_to_basis(
    density_grid: np.ndarray,
    atoms: ase.Atoms,
    basis_spec: BasisSpec,
) -> np.ndarray:
    """Project a real-space density grid onto the atom-centered basis.

    All coefficients of all atoms are obtained from one linear
    least-squares solve: the basis functions of every atom are evaluated
    on the grid, stacked column-wise into a design matrix of shape
    ``(n_grid, n_atoms * n_coeffs_per_atom)``, and the coefficient
    vector minimising the residual against the flattened density is
    returned. Solving jointly accounts for the overlap between basis
    functions, which a per-channel inner product does not.

    The design matrix is dense float64, so memory grows with
    ``n_grid * n_atoms * n_coeffs_per_atom``.

    Parameters
    ----------
    density_grid : (Nx, Ny, Nz) array
        Real-space density on the grid (CHGCAR-like).
    atoms : ase.Atoms
        Periodic structure. Provides positions and cell.
    basis_spec : BasisSpec
        Basis to project onto.

    Returns
    -------
    (n_atoms, n_coeffs_per_atom) float64 array of coefficients.

    Raises
    ------
    ValueError
        If ``density_grid`` is not three-dimensional.
    """
    density = np.asarray(density_grid, dtype=np.float64)
    if density.ndim != 3:
        raise ValueError(
            f"density_grid must be 3-dimensional (Nx, Ny, Nz); got shape "
            f"{density.shape}"
        )

    n_atoms = len(atoms)
    n_per_atom = basis_spec.n_coeffs_per_atom
    if n_atoms == 0:
        # Nothing to fit; avoid handing an empty system to the solver.
        return np.zeros((0, n_per_atom), dtype=np.float64)

    cell = np.asarray(atoms.get_cell())
    grid_pos = _grid_positions(density.shape, cell)  # (n_grid, 3)
    rho_flat = density.ravel()  # (n_grid,)
    positions = atoms.get_positions()

    # One column block per atom, in atom order.
    B_global = np.empty((grid_pos.shape[0], n_atoms * n_per_atom), dtype=np.float64)
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        for i, pos in enumerate(positions):
            B_global[:, i * n_per_atom : (i + 1) * n_per_atom] = _eval_basis_at_grid(
                pos, grid_pos, cell, basis_spec
            )

    # Minimum-residual fit. When the design matrix is rank-deficient or
    # has fewer rows than columns, lstsq returns the minimum-norm
    # solution among the minimisers.
    c_flat, *_ = np.linalg.lstsq(B_global, rho_flat, rcond=None)
    return c_flat.reshape(n_atoms, n_per_atom)
+ with np.errstate(divide="ignore", invalid="ignore", over="ignore"): + for i, pos in enumerate(positions): + B = _eval_basis_at_grid(pos, grid_pos, cell, basis_spec) + rho_flat += B @ coefficients[i] + + return rho_flat.reshape(grid_shape) diff --git a/salted_ft/rholearn_adapter.py b/salted_ft/rholearn_adapter.py new file mode 100644 index 0000000..94ff5e5 --- /dev/null +++ b/salted_ft/rholearn_adapter.py @@ -0,0 +1,205 @@ +"""SALTED -> rholearn data-format adapter. + +rholearn's training loop consumes basis-coefficient vectors in +metatensor ``TensorMap`` format, with a specific flat-vector layout +that differs from our internal one: + +================== =================================================== +Our layout atom (outer) -> n (radial) -> lambda -> mu + (this is what ``project_chgcar_to_basis`` returns) +rholearn layout atom (outer) -> lambda -> n (radial) -> mu + (see ``rholearn/utils/convert.py::_get_flat_index``) +================== =================================================== + +This module provides three things: + +1. ``build_lmax_nmax(basis_spec, species)`` -- our uniform BasisSpec + expanded into rholearn's per-species ``lmax`` / ``nmax`` dicts. +2. ``dense_to_rholearn_flat`` / ``rholearn_flat_to_dense`` -- the + permutation between the two layouts, ndarray <-> ndarray. Roundtrip + is exact and pinned by tests. +3. ``dense_to_tensormap`` -- the full path that calls rholearn's + ``convert.coeff_vector_ndarray_to_tensormap`` to produce a + ``metatensor.TensorMap``. Lazy-imports rholearn / metatensor. + +The permutation is the load-bearing piece. Get it wrong and rholearn +trains on misordered data; the value at index k of the flat vector +no longer corresponds to the (lambda, n, mu) channel rholearn thinks +it does. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Iterable + +import numpy as np + +from salted_ft.basis import BasisSpec + + +# Path setup for lazy rholearn import. 
Same pattern as +# charge3net_ft/model.py and deepdft_ft/runner.py. +_RHOLEARN_ROOT = Path(__file__).resolve().parent.parent.parent / "rholearn" + + +def _ensure_rholearn_importable() -> None: + if not _RHOLEARN_ROOT.exists(): + raise RuntimeError( + f"rholearn repo not found at {_RHOLEARN_ROOT}.\n" + "Clone it with: git clone https://github.com/lab-cosmo/rholearn " + f"{_RHOLEARN_ROOT}" + ) + if str(_RHOLEARN_ROOT) not in sys.path: + sys.path.insert(0, str(_RHOLEARN_ROOT)) + + +# --------------------------------------------------------------------------- +# Basis spec dict builder +# --------------------------------------------------------------------------- +def build_lmax_nmax( + basis_spec: BasisSpec, species: Iterable[str] +) -> tuple[dict[str, int], dict[tuple[str, int], int]]: + """Expand our uniform BasisSpec into rholearn's per-species dicts. + + Returns + ------- + lmax : ``{species: max_l}`` for every species in ``species`` + nmax : ``{(species, lambda): n_radial}`` for every (species, lambda) + """ + species = list(species) + lmax = {s: basis_spec.max_l for s in species} + nmax = { + (s, lam): basis_spec.n_radial + for s in species + for lam in range(basis_spec.max_l + 1) + } + return lmax, nmax + + +# --------------------------------------------------------------------------- +# Permutation between our layout and rholearn's +# --------------------------------------------------------------------------- +def _our_to_rholearn_permutation(basis_spec: BasisSpec) -> np.ndarray: + """Return the index permutation ``p`` such that ``rholearn_flat[k] == + our_flat[p[k]]`` for a SINGLE atom. 
+ + Our per-atom layout (length ``n_radial * (max_l + 1) ** 2``): + for n in 0..n_radial: + for lambda in 0..max_l: + for mu in -lambda..+lambda: + yield (n, lambda, mu) + + rholearn's per-atom layout (same total length): + for lambda in 0..max_l: + for n in 0..n_radial: + for mu in -lambda..+lambda: + yield (lambda, n, mu) + + The permutation is independent of the species (uniform basis). + """ + n_radial = basis_spec.n_radial + max_l = basis_spec.max_l + + # Source flat index for (n, lambda, mu) in OUR layout: + # n * (max_l + 1) ** 2 + lambda * lambda + (mu + lambda) + # (the second-and-third pieces together index the standard Y_lm slot) + def our_idx(n: int, lam: int, mu: int) -> int: + return n * (max_l + 1) ** 2 + lam * lam + (mu + lam) + + # Build the permutation by walking rholearn's order + perm = np.empty(n_radial * (max_l + 1) ** 2, dtype=np.int64) + k = 0 + for lam in range(max_l + 1): + for n in range(n_radial): + for mu in range(-lam, lam + 1): + perm[k] = our_idx(n, lam, mu) + k += 1 + return perm + + +def dense_to_rholearn_flat( + coeffs: np.ndarray, + basis_spec: BasisSpec, + symbols: Iterable[str], +) -> np.ndarray: + """Convert our dense ``(n_atoms, n_coeffs_per_atom)`` coefficients to + rholearn's flat per-structure vector. + + Output length: ``n_atoms * n_coeffs_per_atom``. ``symbols`` is + accepted for API symmetry with the inverse and species-aware + extensions; today the permutation is species-independent because + our BasisSpec is uniform across species. + """ + n_atoms = coeffs.shape[0] + assert coeffs.shape == (n_atoms, basis_spec.n_coeffs_per_atom) + perm = _our_to_rholearn_permutation(basis_spec) + # ``coeffs[:, perm]`` reorders each atom's row from our layout to rholearn's + return coeffs[:, perm].ravel().astype(np.float64) + + +def rholearn_flat_to_dense( + flat: np.ndarray, + basis_spec: BasisSpec, + symbols: Iterable[str], +) -> np.ndarray: + """Inverse of ``dense_to_rholearn_flat``. 
def dense_to_tensormap(
    coeffs: np.ndarray,
    basis_spec: BasisSpec,
    symbols: Iterable[str],
    positions: np.ndarray,
    cell: np.ndarray,
    structure_idx: int = 0,
):
    """Convert dense coefficients to a ``metatensor.TensorMap`` using
    rholearn's converter.

    rholearn and chemfiles are imported inside the function so the
    module stays importable without them.

    Parameters
    ----------
    coeffs : (n_atoms, n_coeffs_per_atom) array in our layout.
    basis_spec : basis the coefficients are defined against.
    symbols : chemical symbol of every atom, in atom order. Any
        iterable is accepted, including one-shot iterators.
    positions : per-atom positions, one row per atom.
    cell : cell passed to ``chemfiles.UnitCell``.
    structure_idx : forwarded to rholearn's converter.

    Raises
    ------
    ValueError
        If ``symbols`` and ``positions`` differ in length.
    """
    _ensure_rholearn_importable()
    import chemfiles
    from rholearn.utils import convert  # type: ignore[import-not-found]

    # ``symbols`` is needed several times below; materialise it once so
    # a generator is not exhausted before the frame is built.
    symbols = list(symbols)

    flat = dense_to_rholearn_flat(coeffs, basis_spec, symbols)
    # Sorted so the per-species dicts have the same key order on every run.
    lmax, nmax = build_lmax_nmax(basis_spec, sorted(set(symbols)))

    # Build a chemfiles Frame from the structure (rholearn's converter
    # expects one).
    frame = chemfiles.Frame()
    frame.cell = chemfiles.UnitCell(np.asarray(cell, dtype=np.float64))
    for sym, pos in zip(symbols, np.asarray(positions), strict=True):
        frame.add_atom(chemfiles.Atom(sym), list(pos))

    return convert.coeff_vector_ndarray_to_tensormap(
        frame,
        coeff_vector=flat,
        lmax=lmax,
        nmax=nmax,
        structure_idx=structure_idx,
        tests=0,
    )
class CfConv(nn.Module):
    """SchNet-style continuous filter convolution with a residual update.

    Each edge message is the source node's linearly transformed feature
    multiplied elementwise by a filter generated from the edge's radial
    features. Messages are summed per destination node, passed through a
    small MLP and added to the input features.
    """

    def __init__(self, hidden_dim: int, n_basis: int):
        super().__init__()
        # Maps per-edge radial features (n_basis) to a per-channel filter.
        self.filter_net = nn.Sequential(
            nn.Linear(n_basis, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pre = nn.Linear(hidden_dim, hidden_dim)
        self.post = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_rbf: torch.Tensor,
    ) -> torch.Tensor:
        """Return the updated node features.

        Parameters
        ----------
        x : (n_nodes, hidden_dim) node features.
        edge_index : (2, n_edges) integer tensor; row 0 holds source
            nodes, row 1 destination nodes. ``n_edges`` may be 0.
        edge_rbf : (n_edges, n_basis) radial features per edge.

        A graph without edges needs no special case: indexing and
        ``index_add_`` with empty index tensors leave the aggregate at
        zero. Always taking this path also keeps ``filter_net`` in the
        autograd graph, so every parameter receives a gradient on every
        sample.
        """
        src, dst = edge_index
        msg = self.pre(x)[src] * self.filter_net(edge_rbf)
        agg = torch.zeros_like(x)
        agg.index_add_(0, dst, msg)
        return x + self.post(agg)
class SaltedTrainingDataset:
    """Join source chunks (positions, cell) with projected coefficients.

    Chunks are paired by file name. Within a pair, each coefficient row
    names its source row through ``row_index``. If the source chunk has
    a ``row_index`` column the match is made by value; otherwise
    ``row_index`` is used as the position of the row in the source
    chunk, which is how the projection step records it.
    """

    # Only the columns that are actually read. The source chunks also
    # carry the charge densities, which are not needed here and would
    # otherwise be converted to Python objects and kept in the cache.
    _SRC_COLUMNS = ("row_index", "cartesian_site_positions", "lattice_vectors")
    _COEFF_COLUMNS = ("row_index", "n_atoms", "atomic_numbers", "coefficients")

    def __init__(
        self,
        source_dir: str | Path,
        coeffs_dir: str | Path,
    ):
        source_dir = Path(source_dir)
        coeffs_dir = Path(coeffs_dir)

        src_files = {p.name: p for p in source_dir.glob("chunk_*.parquet")}
        coeffs_files = {p.name: p for p in coeffs_dir.glob("chunk_*.parquet")}
        common = sorted(set(src_files) & set(coeffs_files))
        if not common:
            raise RuntimeError(
                f"No matching chunk_*.parquet in {source_dir} and {coeffs_dir}"
            )

        # One entry per coefficient row: (chunk file name, row in that file).
        self._index: list[tuple[str, int]] = []
        for name in common:
            n = pq.ParquetFile(coeffs_files[name]).metadata.num_rows
            self._index.extend((name, ri) for ri in range(n))
        self._src_files = src_files
        self._coeffs_files = coeffs_files
        # Per-chunk caches so each parquet is read at most once per
        # process. They are unbounded, but hold only the columns above.
        self._src_cache: dict[str, dict] = {}
        self._coeffs_cache: dict[str, dict] = {}
        # chunk name -> {row_index value: source position}, or None when
        # the source chunk has no row_index column.
        self._src_row_lookup: dict[str, dict | None] = {}

    def __len__(self) -> int:
        return len(self._index)

    @staticmethod
    def _read_columns(path: Path, wanted: tuple[str, ...]) -> dict:
        """Read the subset of ``wanted`` columns that exist in ``path``."""
        available = set(pq.read_schema(path).names)
        present = [col for col in wanted if col in available]
        return pq.read_table(path, columns=present).to_pydict()

    def _load(self, name: str) -> tuple[dict, dict]:
        """Return ``(source, coefficients)`` column dicts for one chunk."""
        if name not in self._src_cache:
            src = self._read_columns(self._src_files[name], self._SRC_COLUMNS)
            row_ids = src.get("row_index")
            if row_ids is None:
                lookup = None
            else:
                # First occurrence wins, matching a left-to-right search.
                lookup = {}
                for pos, rid in enumerate(row_ids):
                    lookup.setdefault(rid, pos)
            self._src_cache[name] = src
            self._src_row_lookup[name] = lookup
        if name not in self._coeffs_cache:
            self._coeffs_cache[name] = self._read_columns(
                self._coeffs_files[name], self._COEFF_COLUMNS
            )
        return self._src_cache[name], self._coeffs_cache[name]

    def __getitem__(self, idx: int) -> tuple[ase.Atoms, torch.Tensor]:
        """Return ``(atoms, target)`` for sample ``idx``.

        ``target`` is a float32 tensor of shape
        ``(n_atoms, n_coeffs_per_atom)``.

        Raises
        ------
        RuntimeError
            If the coefficient row points at a source row that does not
            exist.
        """
        name, ri = self._index[idx]
        src, coeffs = self._load(name)

        # Projected rows can be a subset of the source rows, so the
        # source row is resolved through row_index, never through ``ri``.
        row_id = coeffs["row_index"][ri]
        lookup = self._src_row_lookup[name]
        if lookup is not None:
            src_ri = lookup.get(row_id)
        else:
            n_src = len(src["cartesian_site_positions"])
            src_ri = row_id if 0 <= row_id < n_src else None
        if src_ri is None:
            raise RuntimeError(
                f"Row {ri} of {name} (row_index={row_id}) has no source counterpart"
            )

        n_atoms = int(coeffs["n_atoms"][ri])
        positions = np.asarray(src["cartesian_site_positions"][src_ri]).reshape(-1, 3)
        cell = np.asarray(src["lattice_vectors"][src_ri]).reshape(3, 3)
        Z = np.asarray(coeffs["atomic_numbers"][ri])
        target = np.asarray(coeffs["coefficients"][ri]).reshape(n_atoms, -1)
        atoms = ase.Atoms(numbers=Z, positions=positions, cell=cell, pbc=True)
        return atoms, torch.from_numpy(target.astype(np.float32))
def _build_cli() -> argparse.ArgumentParser:
    """Build the argument parser for the SALTED baseline training CLI.

    Every tunable that ``train()`` accepts is exposed, including
    ``--log-every``, whose default matches the ``train()`` default.
    """
    p = argparse.ArgumentParser(description="Train the SALTED baseline model.")
    p.add_argument(
        "--source-dir",
        type=Path,
        required=True,
        help="D2 input parquet dir (cartesian_site_positions live here).",
    )
    p.add_argument(
        "--coeffs-dir",
        type=Path,
        required=True,
        help="D2 projected coefficients parquet dir.",
    )
    p.add_argument(
        "--output-ckpt",
        type=Path,
        required=True,
        help="Path for the trained checkpoint .pt file.",
    )
    p.add_argument("--n-epochs", type=int, default=10)
    p.add_argument("--batch-size", type=int, default=8)
    p.add_argument("--learning-rate", type=float, default=1e-3)
    p.add_argument("--device", default="cpu")
    p.add_argument(
        "--log-every",
        type=int,
        default=50,
        help="Print the batch MSE every N optimiser steps (N >= 1).",
    )
    return p


def main(argv: Iterable[str] | None = None) -> None:
    """Parse ``argv`` and run training with a default ``BasisSpec()``.

    Parameters
    ----------
    argv :
        Argument strings. ``None`` lets argparse fall back to the process
        command line.
    """
    args = _build_cli().parse_args(argv)
    train(
        source_dir=args.source_dir,
        coeffs_dir=args.coeffs_dir,
        output_ckpt=args.output_ckpt,
        basis_spec=BasisSpec(),
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        device=args.device,
        log_every=args.log_every,
    )
_METRIC_COLS = ("nmape", "rmse", "nrmse")


def aggregate_frame(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate an in-memory per-row eval frame into one row per arm.

    Parameters
    ----------
    df :
        Per-row eval results with a ``model`` column and one column per
        metric in ``_METRIC_COLS``.

    Returns
    -------
    pd.DataFrame with one row per distinct ``model`` value (sorted) and the
    columns ``model``, ``n_structures``,
    ``{nmape,rmse,nrmse}_{mean,std,median}``. The standard deviation is the
    population value (``ddof=0``). The column set is the same when ``df``
    has zero rows, so downstream writers still emit a header.

    Raises
    ------
    KeyError
        If ``model`` or a metric column is missing.
    """
    required = ("model", *_METRIC_COLS)
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise KeyError(f"eval frame is missing required column(s): {missing}")

    columns = ["model", "n_structures"]
    for metric in _METRIC_COLS:
        columns += [f"{metric}_mean", f"{metric}_std", f"{metric}_median"]

    rows = []
    for model_name, group in df.groupby("model", sort=True):
        row = {"model": model_name, "n_structures": len(group)}
        for metric in _METRIC_COLS:
            values = group[metric]
            row[f"{metric}_mean"] = float(values.mean())
            row[f"{metric}_std"] = float(values.std(ddof=0))
            row[f"{metric}_median"] = float(values.median())
        rows.append(row)
    # Explicit columns keep the schema stable when there are no groups.
    return pd.DataFrame(rows, columns=columns)


def aggregate_per_arm(inputs: list[str | Path]) -> pd.DataFrame:
    """Concatenate the per-row eval parquets and aggregate per arm.

    Parameters
    ----------
    inputs :
        Paths to D7-shaped per-row eval parquets. Files for the same arm are
        concatenated before aggregation.

    Returns
    -------
    pd.DataFrame with one row per arm and columns:
    ``model``, ``n_structures``, ``{nmape,rmse,nrmse}_{mean,std,median}``.

    Raises
    ------
    ValueError
        If ``inputs`` is empty.
    """
    paths = list(inputs)
    if not paths:
        raise ValueError("aggregate_per_arm needs at least one input parquet")
    frames = [pd.read_parquet(p) for p in paths]
    return aggregate_frame(pd.concat(frames, ignore_index=True))


def render_markdown_table(agg: pd.DataFrame) -> str:
    """Render the aggregated table as a GitHub-flavoured markdown table.

    Format::

        | Model | N | NMAPE (%) | RMSE (e/A^3) | NRMSE (%) |
        | --- | --- | --- | --- | --- |
        | salted | 1500 | 32.10 +/- 8.42 | 0.0120 +/- 0.0050 | 28.70 +/- 7.20 |

    Each cell is ``mean +/- std``. A ``|`` inside a model name is escaped so
    it cannot split the row into extra cells.
    """
    header = "| Model | N | NMAPE (%) | RMSE (e/A^3) | NRMSE (%) |"
    sep = "| --- | --- | --- | --- | --- |"
    lines = [header, sep]
    for _, row in agg.iterrows():
        model = str(row["model"]).replace("|", "\\|")
        lines.append(
            f"| {model} | {int(row['n_structures'])} | "
            f"{row['nmape_mean']:.2f} +/- {row['nmape_std']:.2f} | "
            f"{row['rmse_mean']:.4f} +/- {row['rmse_std']:.4f} | "
            f"{row['nrmse_mean']:.2f} +/- {row['nrmse_std']:.2f} |"
        )
    return "\n".join(lines) + "\n"


def build_comparison_table(
    inputs: list[str | Path],
    csv_path: str | Path,
    markdown_path: str | Path,
) -> pd.DataFrame:
    """End-to-end: aggregate, then write the CSV and the markdown table.

    Missing parent directories of both output paths are created. Returns the
    aggregated frame so callers can reuse it.
    """
    agg = aggregate_per_arm(inputs)
    csv_out = Path(csv_path)
    markdown_out = Path(markdown_path)
    for target in (csv_out, markdown_out):
        target.parent.mkdir(parents=True, exist_ok=True)
    csv_out.write_text(agg.to_csv(index=False), encoding="utf-8")
    markdown_out.write_text(render_markdown_table(agg), encoding="utf-8")
    return agg
+ ) + return parser + + +def main() -> None: + args = _build_cli().parse_args() + agg = build_comparison_table( + inputs=args.inputs, csv_path=args.csv, markdown_path=args.markdown + ) + print(render_markdown_table(agg)) + + +if __name__ == "__main__": + main() diff --git a/scripts/density_model_eval.py b/scripts/density_model_eval.py new file mode 100644 index 0000000..560ee68 --- /dev/null +++ b/scripts/density_model_eval.py @@ -0,0 +1,342 @@ +"""Single-model density evaluation across the LeMat-Rho arms (D7). + +Per-structure evaluator: load a model arm, predict the real-space +density on a regular grid for each test row, and write per-structure +NMAPE / RMSE / NRMSE against the ground-truth density into a +parquet file. Driven from the CLI; importable for D8 (the +comparison-table builder) which calls ``evaluate_dataset`` directly. + +Arm coverage +------------ + +* ``salted`` -- fully wired. Stub mode (no ckpt) is supported via + ``SALTEDModel(basis_spec, ckpt_path=None)``; real mode lands when + D6 (SALTED training driver) produces a checkpoint. +* ``charge3net`` -- wired via ``_charge3net_predict_grid``: probes over + the Nx*Ny*Nz grid coordinates are batched through the model. Needs + the charge3net sibling repo to be importable. +* ``deepdft`` -- wired via ``_deepdft_predict_grid``, which reuses + charge3net's graph-construction utilities for the same probe-batched + forward. + +The Graph2Mat arm is parked (see graph2mat_ft/__init__.py); not +exposed here.
def density_nmape(pred: np.ndarray, target: np.ndarray) -> float:
    """Integral-normalised MAPE: sum(|target - pred|) / sum(|target|) * 100.

    Inputs may be arrays or plain sequences. Both sums are accumulated in
    float64 regardless of the input dtype, and ``1e-10`` is added to the
    denominator so an all-zero target does not divide by zero.
    """
    pred_arr = np.asarray(pred)
    target_arr = np.asarray(target)
    abs_err = np.abs(pred_arr - target_arr).sum(dtype=np.float64)
    norm = np.abs(target_arr).sum(dtype=np.float64)
    return float(abs_err / (norm + 1e-10) * 100.0)


def density_rmse(pred: np.ndarray, target: np.ndarray) -> float:
    """Root mean squared error across all grid points.

    Inputs may be arrays or plain sequences; the mean is accumulated in
    float64.
    """
    diff = np.asarray(pred) - np.asarray(target)
    return float(np.sqrt((diff**2).mean(dtype=np.float64)))


def density_nrmse(pred: np.ndarray, target: np.ndarray) -> float:
    """RMSE / mean(|target|) * 100. Comparable across electron counts.

    Reuses ``density_rmse`` for the numerator so the two metrics cannot
    drift apart; ``1e-10`` guards the denominator against an all-zero
    target.
    """
    scale = np.abs(np.asarray(target)).mean(dtype=np.float64)
    return float(density_rmse(pred, target) / (scale + 1e-10) * 100.0)
def _charge3net_predict_grid(
    model: object | None,
    ckpt: str | Path | None,
    atoms: ase.Atoms,
    grid_shape: tuple[int, int, int],
    max_probe_batch: int,
    cutoff: float = 4.0,
) -> np.ndarray:
    """ChargE3Net grid prediction via probe-batched forward.

    Builds the full-grid graph using charge3net's own
    ``KdTreeGraphConstructor`` so atom and probe edges match what
    the model saw during training, batches probes through
    ``split_batch``, and reshapes to ``(Nx, Ny, Nz)``.

    Parameters
    ----------
    model :
        Optional pre-built model; when given, ``ckpt`` is not used.
    ckpt :
        Checkpoint path handed to ``ChargE3NetWrapper`` when ``model`` is
        ``None``.
    atoms :
        Structure whose cell defines the probe grid.
    grid_shape :
        Number of grid points along each cell axis.
    max_probe_batch :
        Batch size passed to ``split_batch``.
    cutoff :
        Cutoff passed to the graph constructor. The default is the value
        that used to be hard-coded here, matching the default of
        ``_deepdft_predict_grid``.

    Loading paths
    -------------
    * ``model`` provided: use it directly. The path tests rely on
      to mock the network without a real ckpt.
    * Else, ``ChargE3NetWrapper(ckpt_path=ckpt)`` is constructed.
      Requires the charge3net sibling repo present at
      ``../charge3net/`` (resolved by ``charge3net_ft.model``).
    """
    import torch

    # Import charge3net_ft.model unconditionally for the sys.path side
    # effect (it adds ../charge3net to sys.path so the src.* helpers
    # below resolve). When the caller supplies a model directly we still
    # need charge3net's data utilities to build the graph.
    import charge3net_ft.model as _c3n_wrapper_module  # noqa: F401

    if model is None:
        from charge3net_ft.model import ChargE3NetWrapper

        model = ChargE3NetWrapper(ckpt_path=ckpt)

    from src.charge3net.data.collate import collate_list_of_dicts
    from src.charge3net.data.graph_construction import KdTreeGraphConstructor
    from src.utils.data import calculate_grid_pos
    from src.utils.predictions import split_batch

    grid_shape_arr = np.asarray(grid_shape, dtype=np.int64)
    # The constructor is handed a density array; an all-zero one of the
    # right shape is used since only predictions are needed here.
    dummy_density = np.zeros(tuple(grid_shape_arr), dtype=np.float32)
    origin = np.zeros(3, dtype=np.float64)
    grid_pos = calculate_grid_pos(dummy_density, origin, atoms.get_cell())

    constructor = KdTreeGraphConstructor(
        cutoff=cutoff, num_probes=None, disable_pbc=False
    )
    graph_dict = constructor(dummy_density, atoms, grid_pos)
    batched = collate_list_of_dicts([graph_dict], pin_memory=False)

    # Switch to inference mode when the object exposes train().
    if hasattr(model, "train"):
        model.train(False)

    preds: list[torch.Tensor] = []
    with torch.no_grad():
        for sub_batch in split_batch(batched, max_probe_batch):
            out = model(sub_batch)
            # squeeze(0) drops the leading axis of the model output --
            # presumably the batch axis of the single collated structure.
            preds.append(out.detach().cpu().squeeze(0))

    rho_flat = torch.cat(preds, dim=0).numpy()
    return rho_flat.reshape(tuple(grid_shape_arr))
+ * model construction via ``densitymodel.PainnDensityModel`` or + ``densitymodel.DensityModel`` (SchNet variant). + * defaults match ``submit_deepdft_adastra.sh``: + num_interactions=3, node_size=128, cutoff=4.0, PaiNN. + + Loading paths + ------------- + * ``model`` provided: use it directly (tests inject mocks here). + * Else, build the model and ``torch.load`` the ckpt. + """ + import torch + + # sys.path side effect + asap3 stub, must happen before importing + # densitymodel even when caller supplied the model. + import deepdft_ft.runner as _deepdft_runner_module # noqa: F401 + + if model is None: + import densitymodel + + if use_painn: + model = densitymodel.PainnDensityModel(num_interactions, node_size, cutoff) + else: + model = densitymodel.DensityModel(num_interactions, node_size, cutoff) + if ckpt is not None: + state = torch.load(str(ckpt), map_location="cpu", weights_only=False) + # DeepDFT's ckpts wrap the state dict in a "model" key + state_dict = state.get("model", state) + model.load_state_dict(state_dict) + + # Reuse the charge3net data layer (DeepDFT input dict is the same). 
def evaluate_dataset(
    model_name: str,
    test_parquet: str | Path,
    ckpt: str | Path | None,
    basis_spec: BasisSpec,
    output: str | Path,
    limit: int | None = None,
    model: object | None = None,
) -> Path:
    """Loop over rows in ``test_parquet`` and write per-row metrics.

    Parameters
    ----------
    model_name :
        Arm name forwarded to ``predict_density``; also stored in the
        ``model`` output column.
    test_parquet :
        Input parquet with one structure per row.
    ckpt :
        Checkpoint forwarded to ``predict_density``. ``None`` is recorded as
        ``"stub"`` in the ``ckpt`` output column.
    basis_spec :
        Forwarded to ``predict_density``.
    output :
        Destination parquet; missing parent directories are created.
    limit :
        Evaluate only the first ``limit`` rows when given.
    model :
        Optional pre-loaded model forwarded to ``predict_density``. Pass it
        to avoid rebuilding the model from ``ckpt`` for every row, which is
        what happens when it is left as ``None``.

    Returns
    -------
    Path of the written parquet. Its columns are always ``model``, ``ckpt``,
    ``material_id``, ``n_atoms``, ``nmape``, ``rmse``, ``nrmse`` -- also when
    no row was evaluated.
    """
    df_in = pd.read_parquet(test_parquet)
    if limit is not None:
        df_in = df_in.head(limit)

    columns = ["model", "ckpt", "material_id", "n_atoms", "nmape", "rmse", "nrmse"]
    ckpt_label = str(ckpt) if ckpt is not None else "stub"
    rows = []
    for _, row in df_in.iterrows():
        atoms = _row_to_atoms(row)
        target, grid_shape = _row_target_grid(row)
        pred = predict_density(
            model_name, atoms, grid_shape, ckpt, basis_spec, model=model
        )
        rows.append(
            {
                "model": model_name,
                "ckpt": ckpt_label,
                "material_id": row.get("material_id"),
                # Fall back to the reconstructed structure when the input
                # has no n_atoms column.
                "n_atoms": int(row.get("n_atoms", len(atoms))),
                "nmape": density_nmape(pred, target),
                "rmse": density_rmse(pred, target),
                "nrmse": density_nrmse(pred, target),
            }
        )

    # Explicit columns keep the schema stable for an empty result.
    out_df = pd.DataFrame(rows, columns=columns)
    out_path = Path(output)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_df.to_parquet(out_path)
    return out_path
Omit for stub mode (where supported).", + ) + parser.add_argument( + "--limit", + type=int, + default=None, + help="Evaluate only the first N rows (smoke-test).", + ) + return parser + + +def main() -> None: + args = _build_cli().parse_args() + out_path = evaluate_dataset( + model_name=args.model, + test_parquet=args.test_parquet, + ckpt=args.ckpt, + basis_spec=BasisSpec(), + output=args.output, + limit=args.limit, + ) + print(f"Wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/scf_speedup_run.py b/scripts/scf_speedup_run.py new file mode 100644 index 0000000..127530a --- /dev/null +++ b/scripts/scf_speedup_run.py @@ -0,0 +1,321 @@ +"""SCF-speedup experiment driver (P4). + +For each row in a held-out test parquet, the driver: + +1. Reconstructs the ``ase.Atoms`` + grid_shape + n_electrons. +2. Predicts the density via the chosen ML arm + (``scripts.density_model_eval.predict_density`` already supports + ``salted``, ``charge3net``, and ``deepdft``). +3. Writes a CHGCAR with VASP's electron-count rescaling so + ``ICHARG=1`` reads a self-consistent total. +4. Builds a paired baseline + predicted Flow via + ``entalsim.dft.scf_speedup.make_scf_speedup_pair`` and submits it + to MongoDB via ``entalsim.core.submit.submit_workflow``. + +The two entalsim callables are dependency-injectable so the driver +unit-tests pass locally without entalsim installed; the CLI imports +them at runtime. +""" + +from __future__ import annotations + +import argparse +import importlib +import json +import logging +import sys +from pathlib import Path +from typing import Any, Callable + +import ase +import numpy as np +import pandas as pd +from pymatgen.io.ase import AseAtomsAdaptor +from tqdm.auto import tqdm + +from salted_ft.basis import BasisSpec +from salted_ft.io import write_chgcar + +logger = logging.getLogger(__name__) + +# scripts/ is not a package; reach the sibling module via sys.path +# (same pattern the test fixture uses). 
def _load_submitted_ids(manifest_path: Path, model_name: str) -> set[str]:
    """Read a JSONL manifest and return material_ids previously submitted.

    Only records whose ``model`` equals ``model_name`` and whose
    ``submitted`` field is exactly ``True`` count. Failed rows
    (``submitted=False``) are intentionally NOT counted so the next run
    retries them.

    A bad line is skipped with a warning instead of aborting the resume:
    blank lines are ignored silently, and malformed JSON, JSON that is not
    an object, and matching records without ``material_id`` are logged.

    Returns an empty set when the manifest does not exist.
    """
    # Same logger object as the module-level one (same name).
    log = logging.getLogger(__name__)
    submitted: set[str] = set()
    try:
        # The manifest is written with json.dumps defaults (ASCII-only
        # output), so UTF-8 decoding is always safe.
        handle = manifest_path.open(encoding="utf-8")
    except FileNotFoundError:
        return submitted

    with handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line:
                continue
            try:
                rec = json.loads(line)
            except json.JSONDecodeError:
                log.warning("Skipping malformed manifest line: %s", line[:80])
                continue
            if not isinstance(rec, dict):
                log.warning("Skipping non-object manifest line: %s", line[:80])
                continue
            if rec.get("model") != model_name or rec.get("submitted") is not True:
                continue
            if "material_id" not in rec:
                log.warning("Skipping manifest line without material_id: %s", line[:80])
                continue
            submitted.add(str(rec["material_id"]))
    return submitted
+ + The driver is resilient to per-row failures: a bad row records + an ``error`` entry and the loop continues. Results stream to a + JSONL manifest after each row so an interrupted run leaves a + resumable record. ``skip_existing=True`` skips rows whose + ``material_id`` is already marked ``submitted=True`` in the + manifest for this ``model_name`` (failed rows are retried). + """ + if model_name in _ARMS_REQUIRING_CKPT and ckpt is None: + raise ValueError( + f"--ckpt is required for arm {model_name!r}; running without " + "weights produces random-init predictions and wastes HPC time. " + "Stub mode is supported only for 'salted'." + ) + + # Lazy-import entalsim callables when the caller did not inject + # mocks. Keeps the test suite passable without entalsim installed. + if make_pair_fn is None: + from entalsim.dft.scf_speedup import make_scf_speedup_pair as make_pair_fn + if submit_fn is None: + from entalsim.core.submit import submit_workflow as submit_fn + + chgcar_root = Path(chgcar_dir) + chgcar_root.mkdir(parents=True, exist_ok=True) + if manifest_path is None: + manifest_path = chgcar_root / "manifest.jsonl" + else: + manifest_path = Path(manifest_path) + manifest_path.parent.mkdir(parents=True, exist_ok=True) + + already_done = ( + _load_submitted_ids(manifest_path, model_name) if skip_existing else set() + ) + if already_done: + logger.info( + "Skipping %d rows already submitted (manifest=%s)", + len(already_done), + manifest_path, + ) + + df_in = pd.read_parquet(test_parquet) + if limit is not None: + df_in = df_in.head(limit) + + ckpt_label = str(ckpt) if ckpt is not None else "stub" + records: list[dict[str, Any]] = [] + + for _, row in tqdm( + df_in.iterrows(), + total=len(df_in), + desc=f"scf_speedup({model_name})", + ): + material_id = str(row["material_id"]) + if material_id in already_done: + logger.info("Skipping %s (already submitted)", material_id) + continue + + record: dict[str, Any] = { + "material_id": material_id, + "model": model_name, + 
"ckpt": ckpt_label, + "submitted": False, + "error": None, + } + try: + atoms = _row_to_atoms(row) + grid_shape = _row_grid_shape(row) + n_electrons = float(row["n_electrons"]) + + density = predict_density(model_name, atoms, grid_shape, ckpt, basis_spec) + + # One directory per (model, material_id) so make_scf_speedup_pair's + # prev_dir mechanism stages the right file. Nested layout + # (chgcar_root///CHGCAR) avoids ambiguity + # for material_ids that contain separator characters. + row_dir = chgcar_root / model_name / material_id + row_dir.mkdir(parents=True, exist_ok=True) + chgcar_path = row_dir / "CHGCAR" + write_chgcar(density, atoms, chgcar_path, n_electrons=n_electrons) + + structure = AseAtomsAdaptor.get_structure(atoms) + metadata = { + "experiment": "scf_speedup", + "material_id": material_id, + "model": model_name, + "ckpt": ckpt_label, + } + flow = make_pair_fn(structure, row_dir, metadata) + + if not dry_run: + submit_fn(flow, project=project, worker=worker) + + record.update( + { + "chgcar_path": str(chgcar_path), + "n_jobs": len(flow.jobs), + "submitted": not dry_run, + } + ) + logger.info( + "%s arm=%s n_jobs=%d submitted=%s", + material_id, + model_name, + record["n_jobs"], + record["submitted"], + ) + except Exception as exc: # noqa: BLE001 -- isolate per-row failures + # Catch broadly: any per-row exception (corrupt parquet, ML + # OOM, mongo timeout) must not kill the rest of the batch. + record["error"] = repr(exc) + logger.exception( + "Row failed material_id=%s arm=%s: %s", + material_id, + model_name, + exc, + ) + finally: + # Stream to manifest after every row so an interrupted + # run leaves a resumable record. Open in append mode so + # parallel runs (different arms, different parquets) can + # share a manifest if pointed at the same path. 
def _build_cli() -> argparse.ArgumentParser:
    """Build the argument parser for the SCF-speedup experiment driver.

    ``--manifest`` defaults to ``None``; ``run_experiment`` then falls back
    to ``manifest.jsonl`` inside the ``--chgcar-dir`` directory, which is
    what the help text spells out.
    """
    p = argparse.ArgumentParser(
        description="SCF-speedup experiment driver: predict CHGCAR, "
        "submit paired r2SCAN single-point Flow per structure."
    )
    p.add_argument(
        "--model",
        required=True,
        choices=("salted", "charge3net", "deepdft"),
        help="Which ML arm to evaluate.",
    )
    p.add_argument(
        "--test-parquet",
        required=True,
        type=Path,
        help="Held-out test split parquet (P-ID or P-OOD).",
    )
    p.add_argument(
        "--chgcar-dir",
        required=True,
        type=Path,
        help="Directory for predicted CHGCAR files; per-row subdirs created.",
    )
    p.add_argument(
        "--project",
        required=True,
        help="jobflow_remote project name (matches a jfremote YAML).",
    )
    p.add_argument(
        "--worker",
        required=True,
        help="jobflow_remote worker name from the project YAML.",
    )
    p.add_argument("--ckpt", type=Path, default=None, help="Model checkpoint path.")
    p.add_argument(
        "--limit", type=int, default=None, help="Process only the first N rows."
    )
    p.add_argument(
        "--dry-run",
        action="store_true",
        help="Write CHGCARs and build Flows but do not submit_workflow.",
    )
    p.add_argument(
        "--manifest",
        type=Path,
        default=None,
        # The directory placeholder had been lost from this text, leaving a
        # default that read as a path at the filesystem root.
        help="JSONL manifest path (default: CHGCAR_DIR/manifest.jsonl). "
        "Streamed after each row so interrupted runs are resumable.",
    )
    p.add_argument(
        "--skip-existing",
        action="store_true",
        help="Skip rows whose material_id is already submitted=True in the "
        "manifest for this model. Failed rows are always retried.",
    )
    return p
+# +# Two training modes (select via LEMATRHO_TRAINING_MODE env): +# pretrained (default) — fine-tune from charge3net_mp.pt (MP, 245 epochs) +# from_scratch — train from random init for direct comparison +# +# Env vars: +# LEMATRHO_TRAINING_MODE pretrained | from_scratch (default: pretrained) +# LEMATRHO_ADASTRA_SETUP override $SETUP (default: /lus/scratch/CT10/cad16353/msiron/charge3net_setup) +# LEMATRHO_DRY_RUN 1 to print the resolved train command and exit +# (used by tests/test_submit_script.py) +# +# Submit examples: +# sbatch submit_charge3net_adastra.sh # pretrained +# sbatch --export=ALL,LEMATRHO_TRAINING_MODE=from_scratch submit_charge3net_adastra.sh # from-scratch +# +# Half-node resource layout (g1xxx mi250-shared has 8 GCDs, 128 CPUs, 256 GB): +# - 4 GCDs (gpus-per-node=4) +# - 64 CPUs (16 per task * 4 tasks) +# - 128 GB RAM +# - 4 tasks, one per GCD, for torch DistributedDataParallel +# Effective batch = batch-size * world_size = 16 * 4 = 64 (matches the +# upstream paper's train_mp_e3_final.yaml: batch_size=16, nnodes=2 x nprocs=2). +#SBATCH --job-name=charge3net_ft +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=4 +#SBATCH --account=c1816212 +#SBATCH --constraint=MI250 +#SBATCH --gpus-per-node=4 +#SBATCH --cpus-per-task=16 +# No --mem here on purpose: SLURM allocates memory proportional to our CPU +# share (64 of 128 logical CPUs = ~128 GB out of the 256 GB node). The +# earlier --mem=125000M was being read as "asking for half the node memory" +# and contributed to SLURM auto-bumping us to EXCLUSIVE mode. Letting SLURM +# pick lets the other half of the node stay schedulable for other jobs. +#SBATCH --time=06:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -eo pipefail + +# --- Paths --- +# Submit dir must be on a scratch with inode headroom (cad16353 currently); the +# account (--account=c1816212 above) handles billing independently. See ADASTRA.md. 
+SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +DATA_DIR="$SETUP/charge3net_data" +MP_CKPT="$SETUP/charge3net/models/charge3net_mp.pt" + +# --- Training mode ----------------------------------------------------------- +TRAINING_MODE="${LEMATRHO_TRAINING_MODE:-pretrained}" +case "$TRAINING_MODE" in + pretrained) + CKPT_PATH="$MP_CKPT" + CKPT_DIR="$SETUP/charge3net_checkpoints" + export WANDB_NAME="pretrained_mp" + ;; + from_scratch) + CKPT_PATH="" # no --ckpt-path -> ChargE3NetWrapper inits from random + CKPT_DIR="$SETUP/charge3net_checkpoints_fromscratch" + export WANDB_NAME="from_scratch" + ;; + *) + echo "ERROR: LEMATRHO_TRAINING_MODE must be 'pretrained' or 'from_scratch'," \ + "got '$TRAINING_MODE'" >&2 + exit 2 + ;; +esac + +mkdir -p "$CKPT_DIR" 2>/dev/null || true + +# --- Build train command ----------------------------------------------------- +# Constructed early so LEMATRHO_DRY_RUN can short-circuit before sourcing venv. +TRAIN_ARGS=( + --parquet-dir "$DATA_DIR" + --save-dir "$CKPT_DIR" + --epochs 50 + --batch-size 16 + --lr 5e-4 + --train-probes 200 + --val-probes 1000 + # num-workers=2 (down from 8): with 4 DDP ranks each forking workers, the + # previous setting created 32 worker processes total and the per-worker + # _TABLE_CACHE in data.py OOM-killed jobs 4971293/4971343 at ~140 GB + # cumulative RSS. The LRU eviction we landed in data.py would help on + # its own, but lowering worker count further drops cache pressure with + # zero loss in throughput at this dataset/grid size. 
+ --num-workers 2 + --wandb-project lemat-rho-charge3net + --wandb-entity dtts + --wandb-mode offline +) +if [ -n "$CKPT_PATH" ]; then + TRAIN_ARGS+=(--ckpt-path "$CKPT_PATH") +fi +if [ -f "$CKPT_DIR/latest.pt" ]; then + TRAIN_ARGS+=(--resume-from "$CKPT_DIR/latest.pt") +fi + +if [ "${LEMATRHO_DRY_RUN:-0}" = "1" ]; then + echo "WANDB_NAME=$WANDB_NAME" + echo "TRAINING_MODE=$TRAINING_MODE" + echo "CKPT_DIR=$CKPT_DIR" + printf 'python -m charge3net_ft.train' + for arg in "${TRAIN_ARGS[@]}"; do + printf ' %s' "$arg" + done + printf '\n' + exit 0 +fi + +# --- Environment ------------------------------------------------------------- +# Proxy is required for any outbound HTTP (pip, HF, W&B). Already in ~/.bashrc +# on Adastra but we re-export here so the job script is self contained. +export HTTP_PROXY=http://proxy-l-adastra.cines.fr:3128 +export HTTPS_PROXY=$HTTP_PROXY +export http_proxy=$HTTP_PROXY +export https_proxy=$HTTP_PROXY + +source "$SETUP/venv311/bin/activate" + +export PYTHONPATH="$WORK_DIR:$SETUP/charge3net:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +# Load W&B key from .env if present. +if [ -f "$WORK_DIR/.env" ]; then + set -a + source "$WORK_DIR/.env" + set +a +fi + +# --- NCCL / DDP reliability tweaks --- +# Job 4977567 (2026-05-21) ran 2h41m, then died from NCCL TCPStore +# "Broken pipe / should dump flag" on the DDP heartbeat. Memory was +# fine (14 GB/task with the LRU cache fix). The crash is on the +# inter-rank communication channel, not the model. These three env +# vars expand the timeout budget so a transient slow rank doesn't +# tear down the whole job. 
+# NCCL_TIMEOUT per-collective timeout (seconds) +# NCCL_ASYNC_ERROR_HANDLING=1 clean shutdown on rank failure +# (no cascading hangs) +# TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC how long a rank can stall +# before HeartbeatMonitor tears +# down the process group +export NCCL_TIMEOUT=3600 +export NCCL_ASYNC_ERROR_HANDLING=1 +export TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC=1800 +export TORCH_NCCL_TRACE_BUFFER_SIZE=1000 # capture more debug info on next crash + +# --- Distributed-training env vars (read by train.py's _setup_ddp) --- +# SLURM sets SLURM_NTASKS, SLURM_PROCID, SLURM_LOCALID for us via srun. +# torch.distributed wants WORLD_SIZE / RANK / LOCAL_RANK plus MASTER_ADDR +# / MASTER_PORT. We export them once here, srun propagates to each task. +export WORLD_SIZE=$SLURM_NTASKS +export MASTER_ADDR=$(scontrol show hostname "$SLURM_NODELIST" | head -n 1) +export MASTER_PORT=29500 +# RANK / LOCAL_RANK are per-task — set in the wrapper srun command below. + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Job dir: $WORK_DIR" +echo "Training mode: $TRAINING_MODE (wandb name: $WANDB_NAME)" +echo "Checkpoint dir: $CKPT_DIR" +echo "WORLD_SIZE=$WORLD_SIZE MASTER_ADDR=$MASTER_ADDR MASTER_PORT=$MASTER_PORT" +rocm-smi || true + +python3 -c " +import torch +print(f'torch: {torch.__version__}') +print(f'CUDA/ROCm available: {torch.cuda.is_available()}') +print(f'device count: {torch.cuda.device_count()}') +" + +cd "$WORK_DIR" + +# --- Train ------------------------------------------------------------------ +# srun launches 4 tasks (--ntasks-per-node=4 from #SBATCH). Each task sees +# SLURM_PROCID = global rank, SLURM_LOCALID = local rank within node. +# The TRAIN_ARGS array is exported as a quoted string so the srun-spawned +# bash can reconstruct it. 
+TRAIN_ARGS_QUOTED="" +for arg in "${TRAIN_ARGS[@]}"; do + TRAIN_ARGS_QUOTED+=" $(printf '%q' "$arg")" +done +export TRAIN_ARGS_QUOTED + +srun --kill-on-bad-exit=1 bash -c ' + export RANK=$SLURM_PROCID + export LOCAL_RANK=$SLURM_LOCALID + # Each task sees ALL 4 GCDs the job was allocated; torch.cuda.set_device(local_rank) + # inside _setup_ddp picks the right one. Restricting visibility per-task here + # would make every task target the same "GCD 0" within its own visibility set. + echo "task RANK=$RANK LOCAL_RANK=$LOCAL_RANK on $(hostname) (will use cuda:$LOCAL_RANK)" + eval "python3 -m charge3net_ft.train $TRAIN_ARGS_QUOTED" +' + +echo "Done. Exit code: $?" diff --git a/submit_deepdft_adastra.sh b/submit_deepdft_adastra.sh new file mode 100644 index 0000000..5fd169c --- /dev/null +++ b/submit_deepdft_adastra.sh @@ -0,0 +1,136 @@ +#!/bin/bash +# DeepDFT training on Adastra (CINES, AMD MI250X), single-GPU paper-faithful. +# +# Faithful to peterbjorgensen/DeepDFT paper settings: +# - 1 GCD (paper used 1x RTX 3090; we use 1x MI250X) +# - batch=2 materials, train=1000 probes/material, val=5000 probes/material +# (hardcoded in deepdft_ft/runner.py, same as upstream) +# - cutoff=4 A, num_interactions=3, node_size=128, PaiNN model +# - max_steps=10,000,000 +# +# Single-GPU keeps the gradient-step semantics identical to the paper. +# DDP code paths in runner.py only fire when WORLD_SIZE>1 -- we leave them +# out here on purpose. If we ever want DDP for DeepDFT we'd also need to +# sweep the LR (effective batch grows with world_size). 
+# +# Env vars: +# LEMATRHO_ADASTRA_SETUP override $SETUP (default: cad16353 scratch) +# LEMATRHO_DEEPDFT_VARIANT painn (default) | schnet +# LEMATRHO_DRY_RUN 1 to print resolved cmd + exit +# +# Submit examples: +# sbatch submit_deepdft_adastra.sh # PaiNN +# sbatch --export=ALL,LEMATRHO_DEEPDFT_VARIANT=schnet submit_deepdft_adastra.sh # SchNet +# +#SBATCH --job-name=deepdft_ft +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --account=c1816212 +#SBATCH --constraint=MI250 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64000M +#SBATCH --time=24:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err + +set -eo pipefail + +# --- Paths --- +SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +DATA_DIR="$SETUP/charge3net_data" +DEEPDFT_REPO="$SETUP/DeepDFT" + +# --- Model variant --- +VARIANT="${LEMATRHO_DEEPDFT_VARIANT:-painn}" +case "$VARIANT" in + painn) + EXTRA_ARGS=(--use_painn_model) + OUTPUT_DIR="$SETUP/deepdft_runs/painn" + export WANDB_NAME="deepdft_painn" + ;; + schnet) + EXTRA_ARGS=() # SchNet is the default architecture, no flag needed + OUTPUT_DIR="$SETUP/deepdft_runs/schnet" + export WANDB_NAME="deepdft_schnet" + ;; + *) + echo "ERROR: LEMATRHO_DEEPDFT_VARIANT must be 'painn' or 'schnet', got '$VARIANT'" >&2 + exit 2 + ;; +esac + +mkdir -p "$OUTPUT_DIR" 2>/dev/null || true + +# --- Build train command ----------------------------------------------------- +# Hyperparameters lifted from pretrained_models/{nmc,qm9,ethylenecarbonate}_painn +# in the upstream DeepDFT repo. Same values across all three published checkpoints. 
+TRAIN_ARGS=( + --dataset "$DATA_DIR" + --output_dir "$OUTPUT_DIR" + --cutoff 4 + --num_interactions 3 + --node_size 128 + --max_steps 10000000 + --device cuda + "${EXTRA_ARGS[@]}" +) +if [ -f "$OUTPUT_DIR/best_model.pth" ]; then + TRAIN_ARGS+=(--load_model "$OUTPUT_DIR/best_model.pth") +fi + +if [ "${LEMATRHO_DRY_RUN:-0}" = "1" ]; then + echo "WANDB_NAME=$WANDB_NAME" + echo "VARIANT=$VARIANT" + echo "OUTPUT_DIR=$OUTPUT_DIR" + printf 'python -m deepdft_ft.runner' + for arg in "${TRAIN_ARGS[@]}"; do + printf ' %s' "$arg" + done + printf '\n' + exit 0 +fi + +# --- Environment ------------------------------------------------------------- +export HTTP_PROXY=http://proxy-l-adastra.cines.fr:3128 +export HTTPS_PROXY=$HTTP_PROXY +export http_proxy=$HTTP_PROXY +export https_proxy=$HTTP_PROXY + +source "$SETUP/venv311/bin/activate" + +export PYTHONPATH="$WORK_DIR:$DEEPDFT_REPO:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +if [ -f "$WORK_DIR/.env" ]; then + set -a + source "$WORK_DIR/.env" + set +a +fi + +# Pin to GCD 0 (single-GPU paper-faithful). Do NOT set WORLD_SIZE so that +# runner.py's _setup_ddp returns the single-process tuple (0, 0, 1). +export HIP_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0 + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Variant: $VARIANT (wandb name: $WANDB_NAME)" +echo "Output dir: $OUTPUT_DIR" +echo "Single-GPU mode (WORLD_SIZE unset)" +rocm-smi || true + +python3 -c " +import torch +print(f'torch: {torch.__version__}') +print(f'CUDA/ROCm available: {torch.cuda.is_available()}') +print(f'device count: {torch.cuda.device_count()}') +" + +cd "$WORK_DIR" + +# --- Train (single GPU, no srun) -------------------------------------------- +python3 -m deepdft_ft.runner "${TRAIN_ARGS[@]}" + +echo "Done. Exit code: $?" 
diff --git a/submit_project_lematrho_adastra.sh b/submit_project_lematrho_adastra.sh new file mode 100644 index 0000000..dd9e66f --- /dev/null +++ b/submit_project_lematrho_adastra.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Phase D2: project the LeMat-Rho parquet dataset onto the SALTED basis. +# +# One-time CPU job. Reads $SETUP/charge3net_data/chunk_*.parquet, +# writes $SETUP/salted_projected_coefficients/chunk_*.parquet via +# salted_ft.project_dataset (one LSQR per row, ~75 ms per row). +# +# Adastra smoke test (1 chunk, 956 valid rows) timed at 71 s wall. +# Full dataset (69 chunks, ~65k rows) extrapolates to ~80 min. +# Budget 2 h with slack. +# +# Env vars +# LEMATRHO_ADASTRA_SETUP override $SETUP (default: cad16353 scratch) +# +#SBATCH --job-name=salted_project_dataset +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --account=c1816212 +#SBATCH --constraint=GENOA +#SBATCH --cpus-per-task=4 +#SBATCH --time=02:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err +# Resource sizing notes (2026-05-22): +# - --partition=genoa-shared rejected by CINES policy ("You are not allowed +# to ask for a partition"), same as --qos=debug. We use --constraint=GENOA +# and let SLURM auto-route based on resource size. +# - Lowered --cpus-per-task from 16 to 4 so SLURM keeps us in genoa-shared +# (it auto-routes to the shared partition for small CPU asks, exclusive +# for larger ones). 4 CPUs is enough for our numpy-LSQR + BLAS thread +# pool; the projection is ~1 min/chunk, single chunk is the bottleneck. 
+ +set -eo pipefail + +SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +INPUT_DIR="$SETUP/charge3net_data" +OUTPUT_DIR="$SETUP/salted_projected_coefficients" + +mkdir -p "$OUTPUT_DIR" 2>/dev/null || true + +source "$SETUP/venv311/bin/activate" +export PYTHONPATH="$WORK_DIR:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +# numpy / lstsq is already multi-threaded via BLAS; cap thread count +# to match the SLURM allocation so we do not oversubscribe the node. +export OMP_NUM_THREADS=$SLURM_CPUS_ON_NODE +export OPENBLAS_NUM_THREADS=$SLURM_CPUS_ON_NODE +export MKL_NUM_THREADS=$SLURM_CPUS_ON_NODE + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Input: $INPUT_DIR" +echo "Output: $OUTPUT_DIR" +echo "CPUs: $SLURM_CPUS_ON_NODE" + +cd "$WORK_DIR" + +python -m salted_ft.project_dataset \ + --input-dir "$INPUT_DIR" \ + --output-dir "$OUTPUT_DIR" + +echo "Done. Exit code: $?" +echo "Counting output chunks:" +ls "$OUTPUT_DIR"/chunk_*.parquet | wc -l diff --git a/submit_salted_baseline_adastra.sh b/submit_salted_baseline_adastra.sh new file mode 100755 index 0000000..86efb1d --- /dev/null +++ b/submit_salted_baseline_adastra.sh @@ -0,0 +1,83 @@ +#!/bin/bash +# Phase D6 (path B): train the SALTED baseline coefficient-prediction +# model on the D2 projected outputs. +# +# Single-GPU MI250X job. Dataset is the 65k r2SCAN structures with +# their pre-projected per-atom basis coefficients (from D2). Loss is +# MSE on the (n_atoms, 100) coefficient vectors. See +# salted_ft/train_baseline.py for the model architecture (SchNet-style +# invariant message passing, 2 cfconv layers). 
+# +# Env vars +# LEMATRHO_ADASTRA_SETUP override $SETUP (default: cad16353 scratch) +# LEMATRHO_DRY_RUN 1 to print resolved cmd and exit +# +# Submit: +# sbatch submit_salted_baseline_adastra.sh +# +#SBATCH --job-name=salted_baseline +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --account=c1816212 +#SBATCH --constraint=MI250 +#SBATCH --gpus-per-node=1 +#SBATCH --cpus-per-task=16 +#SBATCH --mem=64000M +#SBATCH --time=24:00:00 +#SBATCH --output=%x_%j.out +#SBATCH --error=%x_%j.err +# +# Resource sizing notes: +# - Single GCD: the baseline model is tiny (~50k params) and +# saturates the per-atom forward path; DDP across multiple GCDs +# would only help if we batched many structures per step, which +# the per-atom variable size makes awkward. Single-GPU is fine. +# - 24h walltime: 65k rows at ~0.1s/row =~ 1.8h per epoch, so 10 epochs +# =~ 18h, plus margin for I/O and Adastra cold-start. + +set -eo pipefail + +SETUP="${LEMATRHO_ADASTRA_SETUP:-/lus/scratch/CT10/cad16353/msiron/charge3net_setup}" +WORK_DIR="$SETUP/LeMat-Rho" +SOURCE_DIR="$SETUP/charge3net_data" +COEFFS_DIR="$SETUP/salted_projected_coefficients" +OUTPUT_DIR="$SETUP/salted_baseline_runs" +mkdir -p "$OUTPUT_DIR" +CKPT="$OUTPUT_DIR/salted_baseline_${SLURM_JOB_ID:-local}.pt" + +source "$SETUP/venv311/bin/activate" +export PYTHONPATH="$WORK_DIR:$PYTHONPATH" +export PYTHONUNBUFFERED=1 + +# ROCm visibility (mirrors submit_deepdft_adastra.sh) +export HIP_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=$HIP_VISIBLE_DEVICES + +CMD=(python -m salted_ft.train_baseline + --source-dir "$SOURCE_DIR" + --coeffs-dir "$COEFFS_DIR" + --output-ckpt "$CKPT" + --n-epochs 10 + --batch-size 8 + --learning-rate 1e-3 + --device cuda) + +if [[ "${LEMATRHO_DRY_RUN:-0}" == "1" ]]; then + printf '%s ' "${CMD[@]}" + printf '\n' + exit 0 +fi + +echo "Node: $(hostname)" +echo "Account: ${SLURM_JOB_ACCOUNT:-unknown}" +echo "Source dir: $SOURCE_DIR" +echo "Coeffs dir: $COEFFS_DIR" +echo "Ckpt out: $CKPT" + +cd "$WORK_DIR" + +"${CMD[@]}" + +echo 
"Done. Exit code: $?" +echo "Wrote: $CKPT" +ls -lh "$CKPT" diff --git a/tests/test_data.py b/tests/test_data.py index 4ef046b..4ec7efc 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -27,9 +27,13 @@ def _import_data_utils(): # Stub out the charge3net modules so the import succeeds without the repo fake_modules = [ - "src", "src.charge3net", "src.charge3net.data", - "src.charge3net.data.collate", "src.charge3net.data.graph_construction", - "src.utils", "src.utils.data", + "src", + "src.charge3net", + "src.charge3net.data", + "src.charge3net.data.collate", + "src.charge3net.data.graph_construction", + "src.utils", + "src.utils.data", ] stubs = {} for mod in fake_modules: @@ -42,6 +46,7 @@ def _import_data_utils(): # Also patch the existence check so it doesn't raise with patch("pathlib.Path.exists", return_value=True): import importlib + # Force reimport with stubs in place if "charge3net_ft.data" in sys.modules: del sys.modules["charge3net_ft.data"] @@ -54,6 +59,7 @@ def test_roundtrip_3d(self): grid = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]] json_str = json.dumps(grid) from charge3net_ft.data import _parse_grid_json + result = _parse_grid_json(json_str) assert result.shape == (2, 2, 2) assert result.dtype == np.float32 @@ -61,6 +67,7 @@ def test_roundtrip_3d(self): def test_10x10x10(self): from charge3net_ft.data import _parse_grid_json + grid = np.random.rand(10, 10, 10).tolist() result = _parse_grid_json(json.dumps(grid)) assert result.shape == (10, 10, 10) @@ -78,6 +85,7 @@ def _make_row(self): def test_atoms_species(self): import ase from charge3net_ft.data import _row_to_atoms_and_density + row = self._make_row() atoms, density, origin = _row_to_atoms_and_density(row) assert isinstance(atoms, ase.Atoms) @@ -85,21 +93,25 @@ def test_atoms_species(self): def test_pbc(self): from charge3net_ft.data import _row_to_atoms_and_density + atoms, _, _ = _row_to_atoms_and_density(self._make_row()) assert all(atoms.pbc) def 
test_density_shape(self): from charge3net_ft.data import _row_to_atoms_and_density + _, density, _ = _row_to_atoms_and_density(self._make_row()) assert density.shape == (10, 10, 10) def test_origin_is_zero(self): from charge3net_ft.data import _row_to_atoms_and_density + _, _, origin = _row_to_atoms_and_density(self._make_row()) np.testing.assert_array_equal(origin, [0.0, 0.0, 0.0]) def test_unknown_species_raises(self): from charge3net_ft.data import _row_to_atoms_and_density + row = self._make_row() row["species_at_sites"] = ["Xx"] # invalid symbol with pytest.raises(KeyError): @@ -111,16 +123,24 @@ def _write_chunk(self, path: Path, n_valid: int, n_null: int): """Write a synthetic chunk_*.parquet file.""" valid = [json.dumps(np.ones((10, 10, 10)).tolist())] * n_valid null = [None] * n_null - table = pa.table({ - "compressed_charge_density": pa.array(valid + null, type=pa.string()), - "species_at_sites": pa.array([["Fe"]] * (n_valid + n_null)), - "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * (n_valid + n_null)), - "lattice_vectors": pa.array([[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * (n_valid + n_null)), - }) + table = pa.table( + { + "compressed_charge_density": pa.array(valid + null, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * (n_valid + n_null)), + "cartesian_site_positions": pa.array( + [[[0.0, 0.0, 0.0]]] * (n_valid + n_null) + ), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] + * (n_valid + n_null) + ), + } + ) pq.write_table(table, path) def test_counts_valid_rows(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: d = Path(tmp) self._write_chunk(d / "chunk_000.parquet", n_valid=5, n_null=2) @@ -131,6 +151,7 @@ def test_counts_valid_rows(self): def test_index_entries_reference_correct_file(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: d = Path(tmp) 
self._write_chunk(d / "chunk_000.parquet", n_valid=3, n_null=0) @@ -142,6 +163,177 @@ def test_index_entries_reference_correct_file(self): def test_raises_on_empty_dir(self): from charge3net_ft.data import _build_parquet_index + with tempfile.TemporaryDirectory() as tmp: with pytest.raises(FileNotFoundError): _build_parquet_index(Path(tmp)) + + def test_ignores_extra_columns(self): + """Newer LeMat-Rho dataset versions add Bader-analysis columns (e.g. + bader_charges, bader_volumes) alongside the four required columns. + _build_parquet_index and _row_to_atoms_and_density should ignore the + extras transparently: data.py:46 declares an explicit _COLUMNS allowlist + and pq.read_table is called with columns=_COLUMNS. + """ + from charge3net_ft.data import _build_parquet_index, _row_to_atoms_and_density + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + n = 3 + grid = json.dumps(np.ones((10, 10, 10)).tolist()) + table = pa.table( + { + # required columns + "compressed_charge_density": pa.array([grid] * n, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * n), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n + ), + # extras analogous to what Entalpic/lemat-rho-v1 added in 2026: + "bader_charges": pa.array([[0.42]] * n), + "bader_volumes": pa.array([[11.7]] * n), + "material_id": pa.array([f"mat_{i}" for i in range(n)]), + } + ) + pq.write_table(table, d / "chunk_000.parquet") + + # build_parquet_index should still find all 3 valid rows + file_paths, index = _build_parquet_index(d) + assert len(index) == n + assert len(file_paths) == 1 + + # _row_to_atoms_and_density should produce a usable atoms+density + # even when the row dict contains the extras (it indexes the + # required keys directly, so the extras are dead weight). 
+ row = { + "species_at_sites": ["Fe"], + "cartesian_site_positions": [[0.0, 0.0, 0.0]], + "lattice_vectors": [[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]], + "compressed_charge_density": grid, + "bader_charges": [0.42], + "bader_volumes": [11.7], + "material_id": "mat_0", + } + atoms, density, origin = _row_to_atoms_and_density(row) + assert len(atoms) == 1 + assert density.shape == (10, 10, 10) + np.testing.assert_array_equal(origin, np.zeros(3)) + + +# --------------------------------------------------------------------------- +# LRU eviction for the per-worker parquet table cache. +# +# Why this is here (regression test for the OOM that killed jobs 4971293 and +# 4971343): without eviction, each DataLoader worker accumulates every chunk +# it has ever read. With 8 workers per rank x 4 DDP ranks = 32 workers, and +# ~2 GB of pyarrow-decompressed table per chunk, the cache alone can grow to +# ~140 GB on a long run. The OOM hit at MaxRSS=35 GB per rank x 4 = 140 GB, +# above our 125 GB --mem budget. +# +# The fix: cap the cache. A small LRU bounded by `_TABLE_CACHE_MAX_CHUNKS` +# evicts the least-recently-used chunk before adding a new one. +# --------------------------------------------------------------------------- + + +class TestTableCacheLRU: + """LeMatRhoDataset's _TABLE_CACHE must evict to stay below a bounded size.""" + + def _write_n_chunks(self, d: Path, n: int): + for i in range(n): + _write_one_row_chunk(d / f"chunk_{i:03d}.parquet") + + def test_cache_size_is_bounded(self): + """After reading from many chunks, the cache must not contain all of them.""" + from charge3net_ft import data as data_mod + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + n_chunks = 10 + self._write_n_chunks(d, n_chunks) + + # Force a small cap so the test is fast and unambiguous. 
+ original_max = getattr(data_mod, "_TABLE_CACHE_MAX_CHUNKS", None) + data_mod._TABLE_CACHE_MAX_CHUNKS = 3 + data_mod._TABLE_CACHE.clear() + try: + ds = data_mod.LeMatRhoDataset(parquet_dir=d, num_probes=None) + for i in range(len(ds)): + _ = ds._read_row(i) + assert len(data_mod._TABLE_CACHE) <= 3, ( + "cache grew beyond _TABLE_CACHE_MAX_CHUNKS=3; " + f"actual size {len(data_mod._TABLE_CACHE)}" + ) + finally: + if original_max is not None: + data_mod._TABLE_CACHE_MAX_CHUNKS = original_max + data_mod._TABLE_CACHE.clear() + + def test_cache_evicts_least_recently_used(self): + """When the cache is full, the next miss should drop the LRU entry.""" + from charge3net_ft import data as data_mod + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + self._write_n_chunks(d, 5) + data_mod._TABLE_CACHE_MAX_CHUNKS = 2 + data_mod._TABLE_CACHE.clear() + try: + ds = data_mod.LeMatRhoDataset(parquet_dir=d, num_probes=None) + # Touch chunks 0, 1 -> cache holds {0, 1} + ds._read_row(0) + ds._read_row(1) + assert set(data_mod._TABLE_CACHE.keys()) == {0, 1} + # Touch chunk 2 -> the LRU (0) should evict, cache holds {1, 2} + ds._read_row(2) + assert set(data_mod._TABLE_CACHE.keys()) == {1, 2}, ( + f"expected LRU eviction of chunk 0, got cache keys " + f"{set(data_mod._TABLE_CACHE.keys())}" + ) + # Re-access 1 -> bumps 1 to most-recent; cache still {1, 2} + ds._read_row(1) + # Touch 3 -> 2 is now LRU, evict 2, cache holds {1, 3} + ds._read_row(3) + assert set(data_mod._TABLE_CACHE.keys()) == {1, 3}, ( + f"expected LRU eviction of chunk 2 after re-access of 1; " + f"got cache keys {set(data_mod._TABLE_CACHE.keys())}" + ) + finally: + data_mod._TABLE_CACHE.clear() + + def test_cache_max_default_is_reasonable(self): + """The default cap must be > 0 and small enough that 8 workers x cap + worth of cached chunks fits well below per-rank memory budgets. 
+ + With ~2 GB per chunk, a default of 5 caps each worker's cache at ~10 GB + worst case: ~80 GB per rank with 8 workers, ~20 GB with 2 (only chunks + the worker actually saw count; in practice well under). Fitting a 32-GB- + per-rank shared-mode allocation therefore also needs the lower worker count. + """ + from charge3net_ft import data as data_mod + + assert hasattr(data_mod, "_TABLE_CACHE_MAX_CHUNKS"), ( + "_TABLE_CACHE_MAX_CHUNKS must be defined for the LRU to work" + ) + assert 1 <= data_mod._TABLE_CACHE_MAX_CHUNKS <= 20, ( + f"_TABLE_CACHE_MAX_CHUNKS={data_mod._TABLE_CACHE_MAX_CHUNKS} is " + "outside the sensible range [1, 20]; very small evicts too " + "aggressively for shuffled access, very large defeats the cap" + ) + + +def _write_one_row_chunk(path: Path): + """Helper: one valid row per chunk; used by the LRU eviction tests.""" + table = pa.table( + { + "compressed_charge_density": pa.array( + [json.dumps(np.ones((10, 10, 10)).tolist())], type=pa.string() + ), + "species_at_sites": pa.array([["Fe"]]), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]]), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] + ), + } + ) + pq.write_table(table, path) diff --git a/tests/test_deepdft_data.py b/tests/test_deepdft_data.py new file mode 100644 index 0000000..57ae3dc --- /dev/null +++ b/tests/test_deepdft_data.py @@ -0,0 +1,194 @@ +"""TDD tests for the LeMat-Rho → DeepDFT data adapter. + +DeepDFT (peterbjorgensen/DeepDFT) consumes a per-sample dict of the form:: + + { + "density": np.ndarray (Nx, Ny, Nz), + "atoms": ase.Atoms, + "origin": np.ndarray (3,), + "grid_position": np.ndarray (Nx, Ny, Nz, 3), + "metadata": dict, # must contain "filename" + } + +Our adapter ``LeMatRhoDeepDFTDataset`` reuses the existing +``_row_to_atoms_and_density`` and ``_build_parquet_index`` helpers in +``charge3net_ft.data`` (so the input pipeline is shared between models) and +returns DeepDFT's dict shape directly. No tar/CHGCAR conversion needed. 
+""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import ase +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + + +# --------------------------------------------------------------------------- +# Helpers — write a synthetic chunk_*.parquet with the same schema the real +# LeMat-Rho data has, plus the Bader columns it gained in v1. +# --------------------------------------------------------------------------- +def _write_synthetic_chunk(path: Path, n_valid: int = 3) -> None: + grid = json.dumps(np.ones((10, 10, 10), dtype=np.float32).tolist()) + table = pa.table( + { + "compressed_charge_density": pa.array([grid] * n_valid, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n_valid), + "cartesian_site_positions": pa.array([[[0.0, 0.0, 0.0]]] * n_valid), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n_valid + ), + # extras DeepDFT must ignore + "bader_charges": pa.array([[0.42]] * n_valid), + "material_id": pa.array([f"mat_{i}" for i in range(n_valid)]), + } + ) + pq.write_table(table, path) + + +class TestLeMatRhoDeepDFTDataset: + """Adapter __getitem__ returns DeepDFT's exact dict contract.""" + + def test_length_matches_valid_rows(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=5) + _write_synthetic_chunk(d / "chunk_001.parquet", n_valid=3) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + assert len(ds) == 8 + + def test_item_has_all_required_keys(self): + """DeepDFT's collate_fn reads density, atoms, origin, grid_position, metadata.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + sample = ds[0] + for key in 
("density", "atoms", "origin", "grid_position", "metadata"): + assert key in sample, ( + f"DeepDFT expects key {key!r}; got {list(sample.keys())}" + ) + + def test_item_density_is_3d_numpy(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["density"], np.ndarray) + assert sample["density"].shape == (10, 10, 10), ( + f"expected (10, 10, 10) density; got {sample['density'].shape}" + ) + + def test_item_atoms_is_ase_atoms_with_pbc(self): + """Periodic boundary conditions matter for any solid-state density.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["atoms"], ase.Atoms) + assert all(sample["atoms"].pbc), ( + "LeMat-Rho cells are periodic; atoms.pbc must be (True, True, True)" + ) + + def test_item_origin_is_3vec_zeros(self): + """LeMat-Rho stores grids at fractional (0, 0, 0); the adapter mirrors that.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert isinstance(sample["origin"], np.ndarray) + np.testing.assert_array_equal(sample["origin"], np.zeros(3)) + + def test_item_grid_position_shape_matches_density(self): + """grid_position is (Nx, Ny, Nz, 3) Cartesian probe coordinates.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + assert sample["grid_position"].shape == (10, 10, 10, 
3), ( + f"grid_position must be (Nx, Ny, Nz, 3); got {sample['grid_position'].shape}" + ) + + def test_grid_position_origin_is_zero(self): + """grid_position[0, 0, 0] must be the cell origin (0, 0, 0).""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + np.testing.assert_allclose(sample["grid_position"][0, 0, 0], np.zeros(3)) + + def test_grid_position_uses_cell_matrix(self): + """grid_position[1, 0, 0] should be one step along the a vector. + + For our synthetic 10×10×10 grid with a 4-Å cubic cell: + frac coord at index (1, 0, 0) = (1/10, 0, 0) + Cartesian = frac @ cell = (4/10, 0, 0) = (0.4, 0, 0) + """ + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + np.testing.assert_allclose( + sample["grid_position"][1, 0, 0], [0.4, 0.0, 0.0], atol=1e-5 + ) + + def test_item_metadata_has_filename(self): + """DeepDFT logs reference filename — must be a stable string per sample.""" + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=2) + ds = LeMatRhoDeepDFTDataset(parquet_dir=d) + for i in range(len(ds)): + meta = ds[i]["metadata"] + assert "filename" in meta, f"metadata missing 'filename'; got {meta}" + assert isinstance(meta["filename"], str) + # Filenames should differ across samples so DeepDFT logs don't collide. + assert ds[0]["metadata"]["filename"] != ds[1]["metadata"]["filename"] + + def test_ignores_extra_columns(self): + """Bader / material_id columns added to LeMat-Rho v1 are dead weight here. 
+ + Same regression we already pinned for charge3net_ft.data; mirroring it + on the DeepDFT path keeps the two adapters honest in lockstep. + """ + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "chunk_000.parquet", n_valid=1) + sample = LeMatRhoDeepDFTDataset(parquet_dir=d)[0] + # The synthetic chunk includes bader_charges + material_id columns. + # The adapter should successfully ingest the row regardless. + assert sample["density"].shape == (10, 10, 10) + + +class TestRaisesOnEmptyDir: + def test_no_chunks_in_dir_raises(self): + from deepdft_ft.data import LeMatRhoDeepDFTDataset + + with tempfile.TemporaryDirectory() as tmp: + with pytest.raises(FileNotFoundError): + LeMatRhoDeepDFTDataset(parquet_dir=Path(tmp)) diff --git a/tests/test_density_model_comparison.py b/tests/test_density_model_comparison.py new file mode 100644 index 0000000..b53a194 --- /dev/null +++ b/tests/test_density_model_comparison.py @@ -0,0 +1,149 @@ +"""TDD tests for ``scripts/density_model_comparison_table.py`` (D8). + +Takes the per-arm parquet outputs from D7 and aggregates into a +single comparison table (markdown + CSV). Per-row metrics are +summarised per arm: mean +/- std and median, with the number of +structures evaluated. 
+""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + + +@pytest.fixture +def comparison_module(): + """Import scripts.density_model_comparison_table.""" + scripts_dir = Path(__file__).resolve().parent.parent / "scripts" + if str(scripts_dir) not in sys.path: + sys.path.insert(0, str(scripts_dir)) + if "density_model_comparison_table" in sys.modules: + del sys.modules["density_model_comparison_table"] + return importlib.import_module("density_model_comparison_table") + + +def _toy_eval_parquet( + tmp_path: Path, + model_name: str, + nmape_values: list[float], + rmse_values: list[float], + nrmse_values: list[float], +) -> Path: + """Write a D7-shaped eval-output parquet with known metric values.""" + df = pd.DataFrame( + { + "model": model_name, + "ckpt": "stub", + "material_id": [f"mp-{model_name}-{i}" for i in range(len(nmape_values))], + "n_atoms": 2, + "nmape": nmape_values, + "rmse": rmse_values, + "nrmse": nrmse_values, + } + ) + out = tmp_path / f"eval_{model_name}.parquet" + df.to_parquet(out) + return out + + +class TestAggregate: + def test_returns_one_row_per_arm(self, tmp_path, comparison_module): + p1 = _toy_eval_parquet( + tmp_path, "salted", [10.0, 20.0], [0.1, 0.2], [5.0, 10.0] + ) + p2 = _toy_eval_parquet( + tmp_path, "charge3net", [5.0, 7.0], [0.05, 0.07], [2.0, 3.0] + ) + df = comparison_module.aggregate_per_arm([p1, p2]) + assert set(df["model"]) == {"salted", "charge3net"} + + def test_mean_nmape_matches_input(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", [10.0, 30.0], [0.1, 0.3], [5.0, 15.0]) + df = comparison_module.aggregate_per_arm([p]) + row = df.iloc[0] + assert row["nmape_mean"] == pytest.approx(20.0) + assert row["rmse_mean"] == pytest.approx(0.2) + assert row["nrmse_mean"] == pytest.approx(10.0) + + def test_std_present(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", 
[10.0, 30.0], [0.1, 0.3], [5.0, 15.0]) + df = comparison_module.aggregate_per_arm([p]) + for col in ("nmape_std", "rmse_std", "nrmse_std"): + assert col in df.columns + assert np.isfinite(df[col].iloc[0]) + + def test_median_present(self, tmp_path, comparison_module): + p = _toy_eval_parquet( + tmp_path, "salted", [10.0, 20.0, 30.0], [0.1, 0.2, 0.3], [5.0, 10.0, 15.0] + ) + df = comparison_module.aggregate_per_arm([p]) + assert df["nmape_median"].iloc[0] == pytest.approx(20.0) + + def test_n_structures_counts_rows(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", [1.0, 2.0, 3.0], [0.1] * 3, [1.0] * 3) + df = comparison_module.aggregate_per_arm([p]) + assert df["n_structures"].iloc[0] == 3 + + def test_aggregates_multiple_files_per_arm(self, tmp_path, comparison_module): + """If the same arm is split across two parquets, aggregate + should treat them as one group. Useful when sharded eval + runs write per-chunk outputs.""" + (tmp_path / "p1").mkdir() + (tmp_path / "p2").mkdir() + p1 = _toy_eval_parquet(tmp_path / "p1", "salted", [10.0], [0.1], [5.0]) + p2 = _toy_eval_parquet(tmp_path / "p2", "salted", [30.0], [0.3], [15.0]) + df = comparison_module.aggregate_per_arm([p1, p2]) + assert len(df) == 1 + assert df.iloc[0]["n_structures"] == 2 + assert df.iloc[0]["nmape_mean"] == pytest.approx(20.0) + + +class TestRenderMarkdown: + def test_markdown_contains_arm_names(self, tmp_path, comparison_module): + p1 = _toy_eval_parquet( + tmp_path, "salted", [10.0, 20.0], [0.1, 0.2], [5.0, 10.0] + ) + p2 = _toy_eval_parquet( + tmp_path, "charge3net", [5.0, 7.0], [0.05, 0.07], [2.0, 3.0] + ) + df = comparison_module.aggregate_per_arm([p1, p2]) + md = comparison_module.render_markdown_table(df) + assert "salted" in md + assert "charge3net" in md + + def test_markdown_has_header_row(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", [10.0], [0.1], [5.0]) + df = comparison_module.aggregate_per_arm([p]) + md = 
comparison_module.render_markdown_table(df) + # GitHub-flavored markdown table separator + assert "|" in md + assert "---" in md + + +class TestWriteOutputs: + def test_writes_csv(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", [10.0, 20.0], [0.1, 0.2], [5.0, 10.0]) + out_csv = tmp_path / "out.csv" + out_md = tmp_path / "out.md" + comparison_module.build_comparison_table( + inputs=[p], csv_path=out_csv, markdown_path=out_md + ) + assert out_csv.exists() + df = pd.read_csv(out_csv) + assert "model" in df.columns + + def test_writes_markdown(self, tmp_path, comparison_module): + p = _toy_eval_parquet(tmp_path, "salted", [10.0, 20.0], [0.1, 0.2], [5.0, 10.0]) + out_csv = tmp_path / "out.csv" + out_md = tmp_path / "out.md" + comparison_module.build_comparison_table( + inputs=[p], csv_path=out_csv, markdown_path=out_md + ) + assert out_md.exists() + assert "salted" in out_md.read_text() diff --git a/tests/test_density_model_eval.py b/tests/test_density_model_eval.py new file mode 100644 index 0000000..2b375cd --- /dev/null +++ b/tests/test_density_model_eval.py @@ -0,0 +1,347 @@ +"""TDD tests for ``scripts/density_model_eval.py`` (D7). + +Per-structure density evaluation across the LeMat-Rho arms. This +test exercises the SALTED stub path end-to-end (synthesize a tiny +parquet, run the eval, read back the result) and the structural +contract of the arm dispatcher. + +ChargE3Net and DeepDFT grid prediction lands in D7-beta (probe +batching); the eval script must raise NotImplementedError for them +rather than silently fall back to stubs, so a future user does not +get fake metrics on real arms. 
+""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path + +import ase +import numpy as np +import pandas as pd +import pytest + + +@pytest.fixture +def eval_module(): + """Import scripts.density_model_eval, adding scripts/ to sys.path.""" + scripts_dir = Path(__file__).resolve().parent.parent / "scripts" + if str(scripts_dir) not in sys.path: + sys.path.insert(0, str(scripts_dir)) + if "density_model_eval" in sys.modules: + del sys.modules["density_model_eval"] + return importlib.import_module("density_model_eval") + + +def _toy_parquet(tmp_path: Path, n_rows: int = 2) -> Path: + """Synthesise a tiny LeMat-Rho-shaped parquet for eval tests. + + Layout matches the columns ``salted_ft.project_dataset`` writes + plus a ``charge_density`` grid and ``grid_shape`` (the eval is + grid-comparison so we need ground-truth grids).""" + rng = np.random.default_rng(0) + rows = [] + for i in range(n_rows): + grid_shape = (4, 4, 4) + rows.append( + { + "row_index": i, + "material_id": f"mp-toy-{i}", + "n_atoms": 2, + "atomic_numbers": np.array([1, 1], dtype=np.int64), + "positions": np.array( + [[0.0, 0.0, 0.0], [0.74 + 0.01 * i, 0.0, 0.0]], dtype=np.float64 + ).reshape(-1), + "lattice_vectors": (np.eye(3) * 5.0).reshape(-1), + "charge_density": rng.standard_normal(np.prod(grid_shape)).astype( + np.float64 + ), + "grid_shape": np.array(grid_shape, dtype=np.int64), + } + ) + df = pd.DataFrame(rows) + out = tmp_path / "toy_test.parquet" + df.to_parquet(out) + return out + + +class TestMetrics: + def test_nmape_perfect_prediction_is_zero(self, eval_module): + rho = np.array([1.0, 2.0, 3.0]) + assert eval_module.density_nmape(rho, rho) == pytest.approx(0.0) + + def test_nmape_zero_prediction_against_unit_target(self, eval_module): + pred = np.zeros(4) + target = np.ones(4) + # NMAPE = sum(|0 - 1|) / sum(|1|) * 100 = 4 / 4 * 100 = 100 + assert eval_module.density_nmape(pred, target) == pytest.approx(100.0) + + def 
test_rmse_perfect_prediction_is_zero(self, eval_module): + rho = np.array([1.0, 2.0]) + assert eval_module.density_rmse(rho, rho) == pytest.approx(0.0) + + def test_rmse_known(self, eval_module): + pred = np.array([0.0, 0.0]) + target = np.array([3.0, 4.0]) # MSE = (9+16)/2 = 12.5, RMSE = sqrt(12.5) + assert eval_module.density_rmse(pred, target) == pytest.approx(np.sqrt(12.5)) + + def test_nrmse_perfect_prediction_is_zero(self, eval_module): + rho = np.array([1.0, 2.0]) + assert eval_module.density_nrmse(rho, rho) == pytest.approx(0.0) + + def test_metrics_handle_3d_grids(self, eval_module): + """Metrics must work on (Nx, Ny, Nz) arrays, not just flat.""" + rng = np.random.default_rng(1) + pred = rng.standard_normal((4, 4, 4)) + target = rng.standard_normal((4, 4, 4)) + # Should not error and should be finite + for fn in ( + eval_module.density_nmape, + eval_module.density_rmse, + eval_module.density_nrmse, + ): + assert np.isfinite(fn(pred, target)) + + +class TestPredictDensity: + """Per-arm dispatcher contract.""" + + def test_salted_stub_returns_grid_of_correct_shape(self, eval_module): + from salted_ft.basis import BasisSpec + + atoms = ase.Atoms( + "HH", + positions=[[0, 0, 0], [0.74, 0, 0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + grid_shape = (6, 6, 6) + rho = eval_module.predict_density( + "salted", atoms, grid_shape, None, BasisSpec() + ) + assert rho.shape == grid_shape + + def test_charge3net_with_mock_model_returns_grid(self, eval_module): + """Charge3Net dispatcher must build the input dict, batch probes, + and reshape to grid. 
We mock the network with a callable that + returns ones at every probe so we can pin the shape contract + and the reshape order without a real ckpt.""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + + if not (Path(__file__).resolve().parent.parent.parent / "charge3net").exists(): + pytest.skip("charge3net sibling repo not present; integration only") + + atoms = ase.Atoms( + "HH", + positions=[[0, 0, 0], [0.74, 0, 0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + + class MockModel: + calls = 0 + + def train(self, mode): # noqa: ARG002 -- ignored, present for parity + return self + + def __call__(self, sub_batch): + MockModel.calls += 1 + n = int(sub_batch["num_probes"].item()) + # Charge3net returns shape [B=1, n_probes] + return torch.ones((1, n), dtype=torch.float32) + + grid_shape = (6, 6, 6) + rho = eval_module.predict_density( + "charge3net", + atoms, + grid_shape, + None, + BasisSpec(), + model=MockModel(), + max_probe_batch=64, + ) + assert rho.shape == grid_shape + np.testing.assert_array_equal(rho, np.ones(grid_shape, dtype=np.float32)) + # 6^3 = 216 probes, max_probe_batch=64 -> at least 3 forward calls + assert MockModel.calls >= 3 + + def test_charge3net_max_probe_batch_controls_chunking(self, eval_module): + """Lowering max_probe_batch must increase the number of forward + passes proportionally.""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + + if not (Path(__file__).resolve().parent.parent.parent / "charge3net").exists(): + pytest.skip("charge3net sibling repo not present") + + atoms = ase.Atoms( + "HH", + positions=[[0, 0, 0], [0.74, 0, 0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + + class CountingMock: + def __init__(self): + self.calls = 0 + + def train(self, mode): # noqa: ARG002 + return self + + def __call__(self, sub_batch): + self.calls += 1 + n = int(sub_batch["num_probes"].item()) + return torch.zeros((1, n), dtype=torch.float32) + + m1 = CountingMock() + 
eval_module.predict_density( + "charge3net", + atoms, + (8, 8, 8), + None, + BasisSpec(), + model=m1, + max_probe_batch=512, + ) + m2 = CountingMock() + eval_module.predict_density( + "charge3net", + atoms, + (8, 8, 8), + None, + BasisSpec(), + model=m2, + max_probe_batch=32, + ) + # Smaller batch -> more sub-batches + assert m2.calls > m1.calls + + def test_deepdft_with_mock_model_returns_grid(self, eval_module): + """DeepDFT shares ChargE3Net's input dict format (the latter was + forked from the former), so the dispatcher should reuse the same + probe-batching machinery with a DeepDFT-built model. Mock model + pins the shape contract.""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + + # DeepDFT sibling repo is required because the dispatcher's + # sys.path side effect goes through deepdft_ft.runner. + if not (Path(__file__).resolve().parent.parent.parent / "DeepDFT").exists(): + pytest.skip("DeepDFT sibling repo not present; integration only") + if not (Path(__file__).resolve().parent.parent.parent / "charge3net").exists(): + pytest.skip("charge3net sibling repo not present") + + atoms = ase.Atoms( + "HH", + positions=[[0, 0, 0], [0.74, 0, 0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + + class DeepDFTMock: + def train(self, mode): # noqa: ARG002 + return self + + def __call__(self, sub_batch): + n = int(sub_batch["num_probes"].item()) + return torch.full((1, n), 0.5, dtype=torch.float32) + + grid_shape = (4, 4, 4) + rho = eval_module.predict_density( + "deepdft", + atoms, + grid_shape, + None, + BasisSpec(), + model=DeepDFTMock(), + max_probe_batch=32, + ) + assert rho.shape == grid_shape + np.testing.assert_allclose(rho, np.full(grid_shape, 0.5, dtype=np.float32)) + + def test_unknown_arm_raises_value_error(self, eval_module): + from salted_ft.basis import BasisSpec + + atoms = ase.Atoms( + "HH", positions=[[0, 0, 0], [0.74, 0, 0]], cell=np.eye(3) * 5.0, pbc=True + ) + with pytest.raises(ValueError, match="unknown"): + 
eval_module.predict_density("bogus", atoms, (6, 6, 6), None, BasisSpec()) + + +class TestEvaluateDataset: + def test_writes_parquet_with_per_row_metrics(self, tmp_path, eval_module): + from salted_ft.basis import BasisSpec + + in_path = _toy_parquet(tmp_path, n_rows=2) + out_path = tmp_path / "eval_out.parquet" + eval_module.evaluate_dataset( + model_name="salted", + test_parquet=in_path, + ckpt=None, + basis_spec=BasisSpec(), + output=out_path, + ) + assert out_path.exists() + df = pd.read_parquet(out_path) + assert len(df) == 2 + for col in ("material_id", "nmape", "rmse", "nrmse"): + assert col in df.columns + + def test_metrics_are_finite(self, tmp_path, eval_module): + from salted_ft.basis import BasisSpec + + in_path = _toy_parquet(tmp_path, n_rows=2) + out_path = tmp_path / "eval_out.parquet" + eval_module.evaluate_dataset( + model_name="salted", + test_parquet=in_path, + ckpt=None, + basis_spec=BasisSpec(), + output=out_path, + ) + df = pd.read_parquet(out_path) + for col in ("nmape", "rmse", "nrmse"): + assert np.isfinite(df[col]).all() + + def test_records_model_and_ckpt_in_output(self, tmp_path, eval_module): + """Output rows must carry the arm name + ckpt path so a downstream + comparison table can group by model without re-deriving.""" + from salted_ft.basis import BasisSpec + + in_path = _toy_parquet(tmp_path, n_rows=1) + out_path = tmp_path / "eval_out.parquet" + eval_module.evaluate_dataset( + model_name="salted", + test_parquet=in_path, + ckpt=None, + basis_spec=BasisSpec(), + output=out_path, + ) + df = pd.read_parquet(out_path) + assert (df["model"] == "salted").all() + assert df["ckpt"].iloc[0] in (None, "", "stub") + + def test_limit_caps_n_rows_evaluated(self, tmp_path, eval_module): + from salted_ft.basis import BasisSpec + + in_path = _toy_parquet(tmp_path, n_rows=5) + out_path = tmp_path / "eval_out.parquet" + eval_module.evaluate_dataset( + model_name="salted", + test_parquet=in_path, + ckpt=None, + basis_spec=BasisSpec(), + 
output=out_path, + limit=2, + ) + df = pd.read_parquet(out_path) + assert len(df) == 2 diff --git a/tests/test_equivariance.py b/tests/test_equivariance.py new file mode 100644 index 0000000..5b21770 --- /dev/null +++ b/tests/test_equivariance.py @@ -0,0 +1,164 @@ +"""Structural equivariance test for ChargE3Net. + +ChargE3Net predicts the scalar charge density ρ(r). For the model to be +rotationally equivariant (i.e. ρ(R·r; R·atoms) == ρ(r; atoms)), the output +irreps of the probe-side network must contain only ℓ=0 even-parity components +("0e", pure scalars). This is the e3nn-level guarantee: as long as the final +representation is a scalar irrep, the model's output is invariant under SO(3) +acting on the input frame. + +A runtime equivariance check (rotate inputs, predict, compare to predictions +on the unrotated inputs) is the gold standard but requires a real forward +pass on the production-sized model, which is too slow for a CPU unit test. +The structural test here covers the same property at the architecture level. + +Skipped automatically when the upstream charge3net repo isn't on disk. 
+""" + +from __future__ import annotations + +import sys +from pathlib import Path + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Skip if the sibling charge3net repo isn't installed locally +# --------------------------------------------------------------------------- +_CHARGE3NET_ROOT = Path(__file__).resolve().parent.parent.parent / "charge3net" +if not _CHARGE3NET_ROOT.exists(): + pytest.skip( + f"charge3net repo not at {_CHARGE3NET_ROOT}; " + "clone github.com/AIforGreatGood/charge3net there to run this test", + allow_module_level=True, + ) +if str(_CHARGE3NET_ROOT) not in sys.path: + sys.path.insert(0, str(_CHARGE3NET_ROOT)) + +from e3nn import o3 # noqa: E402 +from src.charge3net.models.e3 import E3DensityModel # noqa: E402 + + +@pytest.fixture(scope="module") +def production_model(): + """Build a model with the MP-checkpoint hyperparameters. + + Module-scoped so the (slow) construction happens once for all assertions. + """ + torch.manual_seed(0) + model = E3DensityModel( + num_interactions=3, + num_neighbors=20, + mul=500, + lmax=4, + cutoff=4.0, + basis="gaussian", + num_basis=20, + ) + model.train(False) + return model + + +def test_param_count_matches_mp_checkpoint(production_model): + """Sanity check: the model has the 1.9M params we expect. + + Guards against silently changing the architecture in a way that breaks + checkpoint loading from charge3net_mp.pt. + """ + n_params = sum(p.numel() for p in production_model.parameters()) + assert 1_900_000 <= n_params <= 1_920_000, ( + f"Architecture drift: expected ~1.91M params (MP checkpoint), got {n_params:,}" + ) + + +def test_atom_model_uses_higher_order_irreps(production_model): + """ChargE3Net's atom representation must include ℓ>0 irreps to be 'higher-order'. + + The paper's central claim is that going from ℓ_max=1 to ℓ_max=4 produces + substantially better densities on systems with subtle bonding. 
If someone + accidentally drops the higher-l components (e.g. by passing lmax=0), the + model degenerates to a scalar-only network and silently regresses to a + much weaker baseline. + """ + atom_irreps = production_model.atom_model.atom_irreps_sequence + assert len(atom_irreps) > 0, "atom_irreps_sequence is empty" + final_irreps = atom_irreps[-1] + max_l = max(ir.l for _mul, ir in final_irreps) + assert max_l >= 4, ( + f"Atom representation max ℓ is {max_l}; ChargE3Net's " + f"higher-order claim requires ℓ_max ≥ 4. Got {final_irreps}." + ) + + +def test_atom_model_has_both_parities(production_model): + """The atom representation should include both even (+) and odd (-) parity irreps. + + Without odd-parity components the model can't represent any vector- or + pseudovector-valued atom features, which the higher-order convolutions + need internally. The default get_irreps(mul, lmax) function in e3.py + generates both; this test pins that down. + """ + final_irreps = production_model.atom_model.atom_irreps_sequence[-1] + parities = {ir.p for _mul, ir in final_irreps} + assert parities == {-1, 1}, ( + f"Atom irreps should include both even (p=+1) and odd (p=-1) parities; " + f"got parities {parities}: {final_irreps}" + ) + + +def test_get_irreps_helper_is_balanced(): + """The get_irreps helper in e3.py should produce roughly balanced channel counts. + + This is the function used to construct atom_irreps. If it ever returns + zero-multiplicity for any (l, p) pair at production hyperparameters, the + architecture breaks silently (some irreps disappear). Tests the helper + directly to fail fast. 
+ """ + from src.charge3net.models.e3 import get_irreps + + irreps = get_irreps(500, lmax=4) + multiplicities = [mul for mul, _ in irreps] + assert all(mul > 0 for mul in multiplicities), ( + f"get_irreps(500, 4) produced a zero-multiplicity irrep: {irreps}" + ) + # 5 ℓ levels × 2 parities = 10 entries + assert len(irreps) == 10, ( + f"Expected 10 irreps (5 ℓ × 2 parity), got {len(irreps)}: {irreps}" + ) + + +def test_atom_irreps_sequence_length_matches_num_interactions(production_model): + """One irreps entry per convolution layer (plus the input embedding).""" + seq = production_model.atom_model.atom_irreps_sequence + # num_interactions=3 → 3 convolutions; the sequence holds the post-conv + # representations. Length will be 3 or 4 depending on whether the input + # embedding is included; both are valid, but we pin a sane range. + assert 3 <= len(seq) <= 5, ( + f"atom_irreps_sequence length {len(seq)} is outside the expected " + f"range [3, 5] for num_interactions=3" + ) + + +def test_atom_model_uses_cutoff_consistent_with_kdtree(production_model): + """The cutoff baked into the atom model must match what the dataset uses. + + `KdTreeGraphConstructor` in LeMatRhoDataset uses cutoff=4.0; if the model + is built with a different cutoff, edges fed in at training time won't + match what the convolution layer expects. + """ + assert production_model.atom_model.cutoff == pytest.approx(4.0) + + +def test_e3nn_o3_irreps_are_proper_objects(production_model): + """The atom representation must use e3nn's o3.Irreps wrapper. + + Equivariance is enforced by the o3.Irreps abstraction (which carries + parity information and is consumed by FullyConnectedTensorProduct). If + someone replaces it with a plain list, equivariance silently breaks even + though the forward pass still produces output. 
+ """ + final_irreps = production_model.atom_model.atom_irreps_sequence[-1] + assert isinstance(final_irreps, o3.Irreps), ( + f"Expected o3.Irreps for atom_irreps_sequence[-1]; got {type(final_irreps)}" + ) diff --git a/tests/test_graph2mat_basis.py b/tests/test_graph2mat_basis.py new file mode 100644 index 0000000..c91c0f2 --- /dev/null +++ b/tests/test_graph2mat_basis.py @@ -0,0 +1,163 @@ +"""TDD tests for the Graph2Mat-arm basis adapter (PR zeta-alpha). + +Wraps our uniform ``salted_ft.basis.BasisSpec`` into Graph2Mat's +``PointBasis`` per-species objects. Graph2Mat expects one +``PointBasis`` per atomic species, each carrying its own basis-size, +cutoff, and basis-convention. Our BasisSpec is species-uniform in v1 +so the adapter expands the same spec across every species in a +structure. + +Locked contracts: + +* ``point_basis_for_species(symbol, basis_spec)`` -> ``PointBasis`` + ``.type == symbol``, ``.R == basis_spec.cutoff``, + ``.basis_size == basis_spec.n_coeffs_per_atom``, + ``.basis_convention == 'spherical'``. + +* ``basis_table_for_species(symbols, basis_spec)`` -> dict + ``{symbol: PointBasis}`` so downstream Graph2Mat data processors + can look up by atomic symbol. + +Graph2Mat 0.0.13 PointBasis API: + + PointBasis( + type: str | int, + R: float | ndarray, + basis: str | Sequence[int | (int, int, int)] = (), + basis_convention: 'cartesian'|'spherical'|'siesta_spherical'|'qe_spherical' = 'spherical', + ) + + When ``basis`` is a sequence of ints, the int at position ``l`` + is the number of radial functions for that angular momentum. + So ``basis=[4, 4, 4, 4, 4]`` is 4 radials at each of l=0..4. + ``basis_size`` is the resulting total number of basis functions + per atom: sum_l (2l + 1) * n_radial[l]. 
+""" + +from __future__ import annotations + +import pytest + + +class TestPointBasisForSpecies: + def test_returns_pointbasis_instance(self): + pytest.importorskip("graph2mat") + from graph2mat import PointBasis + + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + pb = point_basis_for_species("Fe", BasisSpec()) + assert isinstance(pb, PointBasis) + + def test_type_field_is_species_symbol(self): + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + pb = point_basis_for_species("Fe", BasisSpec()) + assert pb.type == "Fe" + + def test_R_matches_basis_spec_cutoff(self): + """Radial cutoff: must equal our BasisSpec.cutoff so the + neighbor structure inside Graph2Mat matches charge3net_ft / + deepdft_ft / salted_ft.""" + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + pb = point_basis_for_species("Fe", spec) + assert float(pb.R) == pytest.approx(spec.cutoff) + + def test_basis_size_matches_n_coeffs_per_atom(self): + """The per-atom basis function count Graph2Mat sees must equal + the per-atom coefficient count salted_ft.projection produces. + Mismatch means our projected coefficients couldn't be loaded + into a Graph2Mat density-matrix at all. + """ + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + pb = point_basis_for_species("Fe", spec) + assert pb.basis_size == spec.n_coeffs_per_atom + + def test_basis_convention_is_spherical(self): + """Real spherical harmonics. Cartesian would be the wrong basis + for our projected coefficients (we use real Y_lm in + salted_ft.projection._real_sph_harm). 
+ """ + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + pb = point_basis_for_species("Fe", BasisSpec()) + assert pb.basis_convention == "spherical" + + def test_basis_has_one_entry_per_l(self): + """basis is sanitised by Graph2Mat into a tuple of (n_radial, l, parity) + triples. We expect one triple per l in 0..max_l, each with the same + n_radial value matching our uniform spec. + """ + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + pb = point_basis_for_species("Fe", spec) + # After PointBasis.__post_init__ sanitisation, .basis is + # tuple[tuple[int, int, int], ...] with one entry per l value. + assert len(pb.basis) == spec.max_l + 1 + for entry in pb.basis: + n_radial, lam, _parity = entry + assert n_radial == spec.n_radial, ( + f"n_radial mismatch at l={lam}: got {n_radial}, want {spec.n_radial}" + ) + + def test_different_species_give_separate_pointbasis(self): + """Same spec, different species type field. Sanity check that + adapter doesn't cache or share across species. 
+ """ + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import point_basis_for_species + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + pb_h = point_basis_for_species("H", spec) + pb_fe = point_basis_for_species("Fe", spec) + assert pb_h.type == "H" and pb_fe.type == "Fe" + # But size + cutoff are the same since the spec is uniform + assert pb_h.basis_size == pb_fe.basis_size + assert float(pb_h.R) == float(pb_fe.R) + + +class TestBasisTableForSpecies: + def test_returns_dict_keyed_by_symbol(self): + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import basis_table_for_species + from salted_ft.basis import BasisSpec + + table = basis_table_for_species(("H", "O", "Fe"), BasisSpec()) + assert set(table) == {"H", "O", "Fe"} + + def test_values_are_pointbasis(self): + pytest.importorskip("graph2mat") + from graph2mat import PointBasis + + from graph2mat_ft.basis import basis_table_for_species + from salted_ft.basis import BasisSpec + + table = basis_table_for_species(("H", "Fe"), BasisSpec()) + for v in table.values(): + assert isinstance(v, PointBasis) + + def test_deduplicates_repeated_species(self): + pytest.importorskip("graph2mat") + from graph2mat_ft.basis import basis_table_for_species + from salted_ft.basis import BasisSpec + + # Repeated species in the input list should collapse to one entry. + table = basis_table_for_species(("Fe", "Fe", "Fe", "O"), BasisSpec()) + assert set(table) == {"Fe", "O"} diff --git a/tests/test_graph2mat_io.py b/tests/test_graph2mat_io.py new file mode 100644 index 0000000..e8878b5 --- /dev/null +++ b/tests/test_graph2mat_io.py @@ -0,0 +1,24 @@ +"""TDD tests for the Graph2Mat IO surface (PR zeta-delta). + +graph2mat_ft.io should expose the same read_chgcar / write_chgcar +helpers as salted_ft.io, sharing a single implementation (no +duplicate code). These tests pin that the re-exports are the +identical callable, so a fix in salted_ft.io automatically +propagates to the Graph2Mat arm. 
+""" + +from __future__ import annotations + + +def test_read_chgcar_is_reexport(): + from graph2mat_ft.io import read_chgcar as g2m_read + from salted_ft.io import read_chgcar as salted_read + + assert g2m_read is salted_read + + +def test_write_chgcar_is_reexport(): + from graph2mat_ft.io import write_chgcar as g2m_write + from salted_ft.io import write_chgcar as salted_write + + assert g2m_write is salted_write diff --git a/tests/test_graph2mat_model.py b/tests/test_graph2mat_model.py new file mode 100644 index 0000000..4a942dd --- /dev/null +++ b/tests/test_graph2mat_model.py @@ -0,0 +1,146 @@ +"""TDD tests for ``Graph2MatModel`` (PR zeta-gamma). + +Mirrors ``salted_ft.model.SALTEDModel``: a single-call wrapper that +takes an ASE Atoms and returns ``(n_atoms, n_coeffs_per_atom)`` +coefficients. In stub mode (``ckpt_path=None``) the coefficients +are deterministic and seeded from positions / numbers / basis_spec. +The real Graph2Mat forward pass lands in D6 and is asserted here +to raise NotImplementedError until then -- so the failure mode is +loud rather than silently returning stub output. 
+""" + +from __future__ import annotations + +import ase +import numpy as np +import pytest + + +def _h2_atoms() -> ase.Atoms: + return ase.Atoms( + symbols=("H", "H"), + positions=[[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + + +def _feo_atoms() -> ase.Atoms: + return ase.Atoms( + symbols=("Fe", "O"), + positions=[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], + cell=np.eye(3) * 4.0, + pbc=True, + ) + + +class TestStubMode: + def test_constructible_without_ckpt(self): + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + assert m.ckpt_path is None + + def test_output_shape(self): + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + m = Graph2MatModel(spec) + out = m(_h2_atoms()) + assert out.shape == (2, spec.n_coeffs_per_atom) + + def test_output_dtype(self): + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + out = m(_h2_atoms()) + assert out.dtype == np.float64 + + def test_output_finite(self): + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + out = m(_feo_atoms()) + assert np.isfinite(out).all() + + def test_deterministic_same_input(self): + """Same atoms in -> same coefficients out. Required for the + downstream evaluation pipeline to be reproducible.""" + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + m = Graph2MatModel(spec) + out1 = m(_h2_atoms()) + out2 = m(_h2_atoms()) + np.testing.assert_array_equal(out1, out2) + + def test_position_dependent(self): + """Different positions -> different coefficients. 
Catches the + bug where the stub accidentally seeds only on species (which + would make every Fe2O3 polymorph have identical coeffs).""" + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + a = _h2_atoms() + b = _h2_atoms() + b.positions[1, 0] += 0.1 # nudge the second H + out_a = m(a) + out_b = m(b) + assert not np.array_equal(out_a, out_b) + + def test_species_dependent(self): + """Different atomic numbers should change the seed even at + identical positions.""" + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + a = _h2_atoms() + b = _h2_atoms() + b.numbers[1] = 8 # H -> O + out_a = m(a) + out_b = m(b) + assert not np.array_equal(out_a, out_b) + + def test_small_magnitude(self): + """Stub coefficients should be small (order 1e-3) so the + reconstructed densities stay in the regime where downstream + metric tests run without overflow.""" + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + out = m(_h2_atoms()) + assert np.max(np.abs(out)) < 1.0 + + +class TestRealMode: + def test_with_ckpt_raises_until_d6(self): + """Real Graph2Mat forward is deferred to D6. 
Until then a real + ckpt path must fail loudly rather than silently fall back to + the stub (which would corrupt benchmark results).""" + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec(), ckpt_path="/tmp/fake.ckpt") + with pytest.raises(NotImplementedError): + m(_h2_atoms()) + + +class TestReconstructDensity: + """Convenience helper: predict + reconstruct on a VASP-like grid.""" + + def test_shape_matches_grid(self): + from graph2mat_ft.model import Graph2MatModel + from salted_ft.basis import BasisSpec + + m = Graph2MatModel(BasisSpec()) + grid_shape = (8, 8, 8) + rho = m.reconstruct_density(_h2_atoms(), grid_shape) + assert rho.shape == grid_shape diff --git a/tests/test_graph2mat_projection.py b/tests/test_graph2mat_projection.py new file mode 100644 index 0000000..3a02bcd --- /dev/null +++ b/tests/test_graph2mat_projection.py @@ -0,0 +1,215 @@ +"""TDD tests for the Graph2Mat coefficient projection (PR zeta-beta). + +Path A of the Graph2Mat arm: we keep the same regression target as +SALTED (per-atom basis coefficient vectors from +``salted_ft.projection``) and only ask Graph2Mat for a different +backbone. So the "projection" here is a layout transform, not a +basis change. + +Layout we map between: + +* dense ``coeffs[N_atoms, n_coeffs_per_atom]`` -- what + ``salted_ft.project_chgcar_to_basis`` returns +* flat ``point_labels[N_atoms * n_coeffs_per_atom]`` -- atom-major + concatenation, the shape Graph2Mat's per-node targets take + when every node has the same uniform basis + +Per-atom blocks are kept *contiguous* and *in input order* so the +flat vector lines up with the graph node order Graph2Mat builds +from the structure. + +These tests pin the pack/unpack roundtrip and order contract -- +they do not exercise Graph2Mat's matrix machinery (we do not have +off-site coefficients in v1). 
+""" + +from __future__ import annotations + +import numpy as np +import pytest + + +class TestPackCoeffsToPointLabels: + def test_output_shape(self): + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + coeffs = np.zeros((3, spec.n_coeffs_per_atom)) + flat = pack_coeffs_to_point_labels(coeffs, spec, ("Fe", "O", "H")) + assert flat.shape == (3 * spec.n_coeffs_per_atom,) + + def test_dtype_preserved(self): + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + rng = np.random.default_rng(0) + coeffs = rng.standard_normal((2, spec.n_coeffs_per_atom)).astype(np.float64) + flat = pack_coeffs_to_point_labels(coeffs, spec, ("Fe", "Fe")) + assert flat.dtype == np.float64 + + def test_atoms_concatenated_in_input_order(self): + """Per-atom blocks must appear contiguously and in the order of + the symbols argument, so the flat vector aligns with the graph + node order Graph2Mat builds from the structure.""" + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + per_atom = spec.n_coeffs_per_atom + coeffs = np.zeros((2, per_atom)) + coeffs[0, :] = 1.0 + coeffs[1, :] = 2.0 + flat = pack_coeffs_to_point_labels(coeffs, spec, ("Fe", "O")) + assert np.allclose(flat[:per_atom], 1.0) + assert np.allclose(flat[per_atom:], 2.0) + + def test_within_atom_order_preserved(self): + """Within one atom's block, channels must keep their input order + (no reordering across the channel axis). 
This is the + load-bearing contract for matching what the Graph2Mat model + head learns to emit.""" + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + per_atom = spec.n_coeffs_per_atom + coeffs = np.arange(per_atom, dtype=np.float64).reshape(1, per_atom) + flat = pack_coeffs_to_point_labels(coeffs, spec, ("Fe",)) + np.testing.assert_array_equal(flat, np.arange(per_atom, dtype=np.float64)) + + def test_empty_structure_returns_empty(self): + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + flat = pack_coeffs_to_point_labels( + np.zeros((0, spec.n_coeffs_per_atom)), spec, () + ) + assert flat.shape == (0,) + + def test_symbol_length_mismatch_raises(self): + """N_atoms in coeffs must match len(symbols). Catching this at + the boundary stops a silent off-by-one from polluting the + training set.""" + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + coeffs = np.zeros((2, spec.n_coeffs_per_atom)) + with pytest.raises(ValueError): + pack_coeffs_to_point_labels(coeffs, spec, ("Fe",)) + + def test_wrong_channel_width_raises(self): + """coeffs.shape[1] must equal spec.n_coeffs_per_atom.""" + from graph2mat_ft.projection import pack_coeffs_to_point_labels + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + coeffs = np.zeros((1, spec.n_coeffs_per_atom + 1)) + with pytest.raises(ValueError): + pack_coeffs_to_point_labels(coeffs, spec, ("Fe",)) + + +class TestUnpackPointLabelsToCoeffs: + def test_output_shape(self): + from graph2mat_ft.projection import unpack_point_labels_to_coeffs + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + flat = np.zeros(2 * spec.n_coeffs_per_atom) + coeffs = unpack_point_labels_to_coeffs(flat, spec, ("Fe", "O")) + assert coeffs.shape == (2, spec.n_coeffs_per_atom) + + def 
test_wrong_length_raises(self): + from graph2mat_ft.projection import unpack_point_labels_to_coeffs + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + bad = np.zeros(2 * spec.n_coeffs_per_atom + 1) + with pytest.raises(ValueError): + unpack_point_labels_to_coeffs(bad, spec, ("Fe", "O")) + + +class TestRoundtrip: + def test_roundtrip_single_atom(self): + from graph2mat_ft.projection import ( + pack_coeffs_to_point_labels, + unpack_point_labels_to_coeffs, + ) + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + rng = np.random.default_rng(1) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + flat = pack_coeffs_to_point_labels(coeffs, spec, ("Fe",)) + restored = unpack_point_labels_to_coeffs(flat, spec, ("Fe",)) + np.testing.assert_array_equal(restored, coeffs) + + def test_roundtrip_multi_atom_mixed_species(self): + from graph2mat_ft.projection import ( + pack_coeffs_to_point_labels, + unpack_point_labels_to_coeffs, + ) + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + rng = np.random.default_rng(2) + symbols = ("Fe", "O", "Fe", "H", "O") + coeffs = rng.standard_normal((len(symbols), spec.n_coeffs_per_atom)) + flat = pack_coeffs_to_point_labels(coeffs, spec, symbols) + restored = unpack_point_labels_to_coeffs(flat, spec, symbols) + np.testing.assert_array_equal(restored, coeffs) + + +class TestBasisConfiguration: + """Bundle structure + symbols + (optional) coefficients into a + Graph2Mat-ready container. Used by the ZETA-GAMMA training + driver. 
Lazy-imports graph2mat so test only runs when the dep is + installed (it is in our pyproject).""" + + def test_returns_basisconfiguration_instance(self): + pytest.importorskip("graph2mat") + from graph2mat import BasisConfiguration + + from graph2mat_ft.projection import make_basis_configuration + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]) + cell = np.eye(3) * 4.0 + symbols = ("Fe", "O") + cfg = make_basis_configuration(positions, cell, symbols, spec) + assert isinstance(cfg, BasisConfiguration) + + def test_point_types_indexes_into_basis(self): + """point_types[i] must point at the PointBasis whose type + equals symbols[i]. If this drifts Graph2Mat assigns the wrong + per-species head to each atom.""" + pytest.importorskip("graph2mat") + + from graph2mat_ft.projection import make_basis_configuration + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]) + cell = np.eye(3) * 4.0 + symbols = ("Fe", "O") + cfg = make_basis_configuration(positions, cell, symbols, spec) + # Graph2Mat resolves point_types as indices into the cfg.basis list + types_via_basis = [cfg.basis[t].type for t in cfg.point_types] + assert tuple(types_via_basis) == symbols + + def test_positions_and_cell_round_trip(self): + pytest.importorskip("graph2mat") + + from graph2mat_ft.projection import make_basis_configuration + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + positions = np.array([[0.1, 0.2, 0.3], [1.5, 1.5, 1.5]]) + cell = np.diag([3.0, 4.0, 5.0]) + cfg = make_basis_configuration(positions, cell, ("Fe", "O"), spec) + np.testing.assert_allclose(cfg.positions, positions) + np.testing.assert_allclose(cfg.cell, cell) diff --git a/tests/test_salted_baseline.py b/tests/test_salted_baseline.py new file mode 100644 index 0000000..f520575 --- /dev/null +++ b/tests/test_salted_baseline.py @@ -0,0 +1,259 @@ +"""TDD tests for 
``salted_ft.train_baseline`` (D6 path B). + +A pragmatic PyTorch baseline that predicts per-atom basis +coefficients (the same target SALTED projects to). Architecture is +a small SchNet-style invariant message-passing net + linear +readout to ``n_coeffs_per_atom`` channels. Loss is MSE on the +ground-truth coefficient vectors from D2. + +Tests cover the model contract and the training-loop sanity +check (loss must decrease over a few steps). Real Adastra runs +validate end-to-end NMAPE on the held-out split. +""" + +from __future__ import annotations + +from pathlib import Path + +import ase +import numpy as np +import pandas as pd +import pytest + + +def _h2_atoms() -> ase.Atoms: + return ase.Atoms( + "HH", + positions=[[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]], + cell=np.eye(3) * 5.0, + pbc=True, + ) + + +def _feo_atoms() -> ase.Atoms: + return ase.Atoms( + "FeO", + positions=[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]], + cell=np.eye(3) * 4.0, + pbc=True, + ) + + +class TestModelForward: + def test_output_shape(self): + pytest.importorskip("torch") + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + m = SaltedBaselineModel(BasisSpec()) + out = m(_h2_atoms()) + assert out.shape == (2, BasisSpec().n_coeffs_per_atom) + + def test_output_finite(self): + pytest.importorskip("torch") + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + import torch + + m = SaltedBaselineModel(BasisSpec()) + out = m(_feo_atoms()) + assert torch.isfinite(out).all() + + def test_output_dtype_is_float32(self): + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + m = SaltedBaselineModel(BasisSpec()) + out = m(_h2_atoms()) + assert out.dtype == torch.float32 + + def test_deterministic_with_same_seed(self): + """Same model state + same atoms in -> same coefficients out. 
+ Required for the eval pipeline to be reproducible.""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + torch.manual_seed(0) + m1 = SaltedBaselineModel(BasisSpec()) + torch.manual_seed(0) + m2 = SaltedBaselineModel(BasisSpec()) + out1 = m1(_h2_atoms()) + out2 = m2(_h2_atoms()) + torch.testing.assert_close(out1, out2) + + def test_different_species_changes_output(self): + """Species embedding must carry signal. If H and Fe atoms with + identical positions give identical outputs the embedding is + ignored.""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + torch.manual_seed(0) + m = SaltedBaselineModel(BasisSpec()) + a_hh = ase.Atoms( + "HH", positions=[[0, 0, 0], [2, 0, 0]], cell=np.eye(3) * 5.0, pbc=True + ) + a_he = ase.Atoms( + "HHe", positions=[[0, 0, 0], [2, 0, 0]], cell=np.eye(3) * 5.0, pbc=True + ) + out_hh = m(a_hh) + out_he = m(a_he) + assert not torch.allclose(out_hh, out_he) + + +class TestTrainingStep: + def test_loss_decreases_after_few_steps(self): + """Sanity: optimiser can drive the loss down on a tiny dataset. 
+ Catches obvious wiring bugs (no grads flowing, frozen embedding).""" + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + torch.manual_seed(0) + spec = BasisSpec() + model = SaltedBaselineModel(spec) + opt = torch.optim.Adam(model.parameters(), lr=1e-2) + atoms = _feo_atoms() + target = torch.randn(len(atoms), spec.n_coeffs_per_atom) * 0.1 + + # Loss before any training + with torch.no_grad(): + loss_before = torch.nn.functional.mse_loss(model(atoms), target).item() + + for _ in range(20): + opt.zero_grad() + pred = model(atoms) + loss = torch.nn.functional.mse_loss(pred, target) + loss.backward() + opt.step() + + with torch.no_grad(): + loss_after = torch.nn.functional.mse_loss(model(atoms), target).item() + assert loss_after < loss_before, ( + f"loss did not decrease: before={loss_before:.6f}, after={loss_after:.6f}" + ) + + +class TestSaveLoad: + def test_save_load_preserves_predictions(self, tmp_path): + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import SaltedBaselineModel + + torch.manual_seed(0) + spec = BasisSpec() + m = SaltedBaselineModel(spec) + atoms = _h2_atoms() + out_before = m(atoms) + + ckpt = tmp_path / "model.pt" + torch.save({"basis_spec": spec, "model": m.state_dict()}, ckpt) + + m2 = SaltedBaselineModel(spec) + state = torch.load(ckpt, map_location="cpu", weights_only=False) + m2.load_state_dict(state["model"]) + out_after = m2(atoms) + torch.testing.assert_close(out_before, out_after) + + +def _toy_dataset_dirs(tmp_path: Path, basis_spec, n_rows: int = 2): + """Create matched D2 source + projected parquets in two subdirs. + + The training dataset joins them by ``row_index`` per chunk; the + file basename matches across the two directories so ``chunk_0000`` + in ``source/`` lines up with ``chunk_0000`` in ``coeffs/``. 
+ """ + src_dir = tmp_path / "charge3net_data" + coeffs_dir = tmp_path / "salted_projected_coefficients" + src_dir.mkdir() + coeffs_dir.mkdir() + rng = np.random.default_rng(0) + + src_rows = [] + coeffs_rows = [] + for i in range(n_rows): + n_atoms = 2 + atomic_numbers = [1, 1] + positions = [[0.0, 0.0, 0.0], [0.74 + 0.01 * i, 0.0, 0.0]] + cell = (np.eye(3) * 5.0).tolist() + src_rows.append( + { + "row_index": i, + "material_id": f"mp-{i}", + "n_atoms": n_atoms, + "atomic_numbers": atomic_numbers, + "cartesian_site_positions": [c for row in positions for c in row], + "lattice_vectors": [c for row in cell for c in row], + # Tiny grid so this stays cheap; the projected file is what + # the training loop actually consumes + "grid_shape": [4, 4, 4], + "compressed_charge_density": rng.standard_normal(np.prod((4, 4, 4))) + .astype(np.float32) + .tobytes(), + } + ) + coeffs_rows.append( + { + "row_index": i, + "material_id": f"mp-{i}", + "n_atoms": n_atoms, + "atomic_numbers": atomic_numbers, + "lattice_vectors": cell, + "n_electrons": 2.0, + "grid_shape": [4, 4, 4], + "coefficients": rng.standard_normal( + (n_atoms, basis_spec.n_coeffs_per_atom) + ).tolist(), + "basis_set_NMAPE": 5.0, + } + ) + pd.DataFrame(src_rows).to_parquet(src_dir / "chunk_0000.parquet") + pd.DataFrame(coeffs_rows).to_parquet(coeffs_dir / "chunk_0000.parquet") + return src_dir, coeffs_dir + + +class TestTrainCLI: + """Higher-level: ``train`` end-to-end on a synthetic 2-row dataset. + + Validates that the full data path (parquet pair -> dataset -> + training loop -> ckpt) works without crashing. Real ckpts come + from running ``submit_salted_baseline_adastra.sh``. 
+ """ + + def test_train_writes_ckpt(self, tmp_path): + pytest.importorskip("torch") + + from salted_ft.basis import BasisSpec + from salted_ft.train_baseline import train + + spec = BasisSpec() + src_dir, coeffs_dir = _toy_dataset_dirs(tmp_path, spec, n_rows=2) + ckpt = tmp_path / "salted_baseline.pt" + train( + source_dir=src_dir, + coeffs_dir=coeffs_dir, + output_ckpt=ckpt, + basis_spec=spec, + n_epochs=1, + batch_size=1, + learning_rate=1e-3, + ) + assert ckpt.exists() + import torch + + state = torch.load(ckpt, map_location="cpu", weights_only=False) + assert "model" in state diff --git a/tests/test_salted_basis.py b/tests/test_salted_basis.py new file mode 100644 index 0000000..b8ce7f4 --- /dev/null +++ b/tests/test_salted_basis.py @@ -0,0 +1,155 @@ +"""TDD tests for the SALTED-arm BasisSpec dataclass. + +Locks down the basis numbers chosen in +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` (Phase A4): + +* ``max_l = 4`` +* ``n_radial = 4`` (uniform across species in v1) +* ``sigma = (0.5, 1.0, 2.0, 4.0)`` Å — geometric radial-width ladder +* ``cutoff = 4.0`` Å — matches ChargE3Net's KdTree cutoff +* ``n_coeffs_per_atom == n_radial * (max_l + 1) ** 2`` == 100 + +These numbers are referenced by every downstream PR (projection, +reconstruction, model wrapper, VASP I/O). Pinning them here means a +later edit shows up as a single failing test, not a silent drift. +""" + +from __future__ import annotations + +import pytest + + +class TestBasisSpecDefaults: + """Default BasisSpec must match the A4 lockdown.""" + + def test_default_max_l_is_four(self): + from salted_ft.basis import BasisSpec + + assert BasisSpec().max_l == 4 + + def test_default_n_radial_is_four(self): + from salted_ft.basis import BasisSpec + + assert BasisSpec().n_radial == 4 + + def test_default_sigma_ladder(self): + from salted_ft.basis import BasisSpec + + # Geometric ladder over tight + valence + diffuse regimes. 
+ assert BasisSpec().sigma == (0.5, 1.0, 2.0, 4.0) + + def test_default_cutoff_matches_charge3net(self): + from salted_ft.basis import BasisSpec + + # ChargE3Net's KdTreeGraphConstructor uses cutoff=4.0; the SALTED-arm + # uses the same so atom-neighbor structure is identical between models. + assert BasisSpec().cutoff == pytest.approx(4.0) + + def test_default_n_coeffs_per_atom_is_100(self): + """4 radial * (4+1)^2 angular = 100 coefficients per atom.""" + from salted_ft.basis import BasisSpec + + assert BasisSpec().n_coeffs_per_atom == 100 + + +class TestBasisSpecArithmetic: + """n_coeffs_per_atom must equal n_radial * (max_l + 1)^2 for any valid spec.""" + + @pytest.mark.parametrize( + "max_l,n_radial,expected", + [ + (0, 1, 1), # one s function + (1, 2, 8), # 2 * (1 + 3) = 8 + (2, 3, 27), # 3 * (1 + 3 + 5) = 27 + (4, 4, 100), # the production default + (6, 4, 196), # 4 * (1 + 3 + 5 + 7 + 9 + 11 + 13) = 196 + ], + ) + def test_n_coeffs_formula(self, max_l, n_radial, expected): + from salted_ft.basis import BasisSpec + + spec = BasisSpec( + max_l=max_l, + n_radial=n_radial, + sigma=tuple(0.5 * 2**i for i in range(n_radial)), + cutoff=5.0, + ) + assert spec.n_coeffs_per_atom == expected + + def test_n_radial_matches_sigma_length(self): + """sigma is the per-radial-channel width; len(sigma) must equal n_radial.""" + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"n_radial.*sigma"): + BasisSpec(max_l=2, n_radial=3, sigma=(0.5, 1.0), cutoff=4.0) + + +class TestBasisSpecValidation: + """Reject malformed specs at construction time, not at use time.""" + + def test_negative_max_l_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"max_l"): + BasisSpec(max_l=-1, n_radial=4, sigma=(0.5, 1.0, 2.0, 4.0), cutoff=4.0) + + def test_zero_n_radial_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"n_radial"): + BasisSpec(max_l=4, n_radial=0, sigma=(), 
cutoff=4.0) + + def test_negative_sigma_rejected(self): + """sigma is a Gaussian width; nonpositive widths are nonphysical.""" + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"sigma"): + BasisSpec(max_l=2, n_radial=2, sigma=(0.5, -1.0), cutoff=4.0) + + def test_nonpositive_cutoff_rejected(self): + from salted_ft.basis import BasisSpec + + with pytest.raises(ValueError, match=r"cutoff"): + BasisSpec(max_l=2, n_radial=2, sigma=(0.5, 1.0), cutoff=0.0) + + +class TestBasisSpecShapes: + """Shape helpers for downstream tensor allocation.""" + + def test_n_angular_components_per_radial(self): + """(max_l + 1)^2 real spherical harmonic components per radial channel.""" + from salted_ft.basis import BasisSpec + + # l=0,1,2,3,4 -> 1+3+5+7+9 = 25 angular components per radial channel + assert BasisSpec().n_angular_components == 25 + + def test_total_coeffs_shape(self): + """coeffs tensor shape for a structure: (n_atoms, n_coeffs_per_atom).""" + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + assert spec.total_coeffs_shape(n_atoms=5) == (5, 100) + assert spec.total_coeffs_shape(n_atoms=1) == (1, 100) + + +class TestBasisSpecImmutable: + """BasisSpec must be hashable + immutable so it can key caches / metric runs.""" + + def test_is_hashable(self): + from salted_ft.basis import BasisSpec + + # Two specs with identical fields hash to the same value. 
+ a = BasisSpec() + b = BasisSpec() + assert hash(a) == hash(b) + assert a == b + + def test_mutation_rejected(self): + """Frozen dataclass — assigning to a field raises FrozenInstanceError.""" + from dataclasses import FrozenInstanceError + + from salted_ft.basis import BasisSpec + + spec = BasisSpec() + with pytest.raises(FrozenInstanceError): + spec.max_l = 6 # type: ignore[misc] diff --git a/tests/test_salted_io.py b/tests/test_salted_io.py new file mode 100644 index 0000000..61a00e1 --- /dev/null +++ b/tests/test_salted_io.py @@ -0,0 +1,207 @@ +"""TDD tests for VASP CHGCAR I/O wrapper (PR delta). + +The wrapper exposes ``write_chgcar(density, atoms, path)`` so a +reconstructed real-space density grid can be persisted as a VASP +CHGCAR file. That file is then the input to a paired SCF run +(``ICHARG=1``) for the speedup comparison vs the +``ICHARG=2``-from-superposition baseline. + +Locked contract: + +* ``write_chgcar(density, atoms, path, n_electrons=None)`` + Writes a pymatgen ``Chgcar``-compatible file at ``path``. If + ``n_electrons`` is given, rescales the density so that + ``sum(density) * cell_volume / N_grid == n_electrons``. +* The written file round-trips through ``Chgcar.from_file`` and + preserves shape, atom species, and cell. +* ``read_chgcar(path)`` is the inverse: returns + ``(density: np.ndarray, atoms: ase.Atoms)``. + +End-to-end SCF speedup test is gated on the entalsim +``StructureVASPSinglePoint`` maker landing; pinned here as an +``importorskip`` placeholder so it auto-activates when the +dependency arrives. 
+""" + +from __future__ import annotations + +import tempfile +from pathlib import Path + +import ase +import numpy as np +import pytest + + +def _cubic_atoms(symbols=("Fe",), fractional=((0.5, 0.5, 0.5),), a=4.0): + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +class TestWriteChgcar: + def test_writes_file(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + density = np.ones((8, 8, 8), dtype=np.float64) + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + assert path.exists() + assert path.stat().st_size > 0 + + def test_normalizes_to_total_electron_count(self): + """When ``n_electrons`` is set, the *integrated* density of the + written file must equal ``n_electrons`` to within ``1e-4 * n_electrons``. + That's what VASP reads as N_electrons on ICHARG=1. + """ + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + # Density that integrates to something arbitrary; write_chgcar + # should rescale to the requested electron count. 
+ density = np.ones((8, 8, 8), dtype=np.float64) * 0.5 + target_n = 26.0 # Fe valence electron count, roughly + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path, n_electrons=target_n) + read_density, _ = read_chgcar(path) + # CHGCAR convention: density * volume / N_grid integrates to N_electrons + cell_volume = np.linalg.det(atoms.get_cell()) + n_grid = np.prod(read_density.shape) + total_e = read_density.sum() * cell_volume / n_grid + assert abs(total_e - target_n) / target_n < 1e-4, ( + f"integrated density {total_e:.6f} differs from target {target_n} " + "by more than 1e-4; CHGCAR normalization is wrong" + ) + + def test_rejects_non_3d_density(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + with pytest.raises(ValueError, match=r"3D"): + write_chgcar(np.ones((8, 8)), atoms, path) + + def test_rejects_negative_n_electrons(self): + from salted_ft.io import write_chgcar + + atoms = _cubic_atoms() + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + with pytest.raises(ValueError, match=r"n_electrons"): + write_chgcar(np.ones((8, 8, 8)), atoms, path, n_electrons=-1.0) + + +class TestReadChgcar: + def test_returns_density_and_atoms(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + density = np.ones((8, 8, 8), dtype=np.float64) * 0.1 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + read_density, read_atoms = read_chgcar(path) + assert read_density.shape == (8, 8, 8) + assert isinstance(read_atoms, ase.Atoms) + + def test_preserves_atom_species(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms( + symbols=("Fe", "O"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + density = np.ones((8, 8, 8), dtype=np.float64) * 0.1 + with 
tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + _, read_atoms = read_chgcar(path) + # Order may differ but the multiset of species must match. + assert sorted(read_atoms.get_chemical_symbols()) == sorted( + atoms.get_chemical_symbols() + ) + + def test_preserves_cell(self): + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms(a=5.0) + density = np.ones((4, 4, 4), dtype=np.float64) * 0.05 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + _, read_atoms = read_chgcar(path) + np.testing.assert_allclose( + np.asarray(read_atoms.get_cell()), + np.asarray(atoms.get_cell()), + atol=1e-6, + ) + + +class TestRoundtrip: + def test_density_roundtrip_within_tolerance(self): + """Write then read: shape exact, values within VASP-precision tolerance. + + VASP CHGCAR uses 5-decimal scientific notation per value, so + we expect ~1e-5 relative precision. + """ + from salted_ft.io import read_chgcar, write_chgcar + + atoms = _cubic_atoms() + rng = np.random.default_rng(7) + density = rng.random((8, 8, 8)).astype(np.float64) * 0.1 + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + read_density, _ = read_chgcar(path) + assert read_density.shape == density.shape + np.testing.assert_allclose(read_density, density, rtol=1e-3, atol=1e-5) + + +class TestSALTEDModelToChgcar: + """End-to-end: predict via SALTEDModel, reconstruct, write CHGCAR.""" + + def test_predicted_density_writes_to_chgcar(self): + from salted_ft.basis import BasisSpec + from salted_ft.io import read_chgcar, write_chgcar + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms() + model = SALTEDModel(basis_spec=BasisSpec()) + density = model.reconstruct_density(atoms, (8, 8, 8)) + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, path) + assert 
path.exists() + read_density, _ = read_chgcar(path) + assert read_density.shape == (8, 8, 8) + + +# --------------------------------------------------------------------------- +# Forward-looking placeholder for the entalsim integration. +# +# Once Entalpic/entalsim PR #56's PR 2 (StructureVASPSinglePoint maker) +# lands and is installable, this test will auto-activate. Until then it +# skips cleanly so the suite stays green. +# --------------------------------------------------------------------------- +class TestVASPSinglePointHook: + def test_chgcar_consumed_by_entalsim_single_point_maker(self): + # Skips until entalsim ships the maker. + pytest.importorskip("entalsim.dft.tasks.single_point") + from entalsim.dft.tasks.single_point import StructureVASPSinglePoint + + from salted_ft.basis import BasisSpec + from salted_ft.io import write_chgcar + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms() + model = SALTEDModel(basis_spec=BasisSpec()) + density = model.reconstruct_density(atoms, (8, 8, 8)) + with tempfile.TemporaryDirectory() as tmp: + chgcar = Path(tmp) / "CHGCAR" + write_chgcar(density, atoms, chgcar) + # Maker should accept the written CHGCAR for ICHARG=1. + maker = StructureVASPSinglePoint(initial_chgcar=chgcar) + assert maker.initial_chgcar == chgcar diff --git a/tests/test_salted_model.py b/tests/test_salted_model.py new file mode 100644 index 0000000..7977611 --- /dev/null +++ b/tests/test_salted_model.py @@ -0,0 +1,303 @@ +"""TDD tests for the SALTEDModel wrapper (PR gamma). + +The wrapper exposes ``__call__(atoms) -> coefficients`` so SALTED-style +predictions plug into the projection / reconstruction layer from PR beta. + +Locked contract: + +* ``SALTEDModel(basis_spec, ckpt_path=None)`` — construct. When + ``ckpt_path`` is None the wrapper produces deterministic + position-dependent stub coefficients (lets us run tests + the + reconstruction pipeline without a real rholearn checkpoint). 
+ +* ``model(atoms)`` returns ``np.ndarray (n_atoms, n_coeffs_per_atom)``, + float64, finite, deterministic for fixed inputs. + +* ``model.reconstruct_density(atoms, grid_shape)`` returns the density + grid in the same shape ``reconstruct_grid_from_basis`` would have + produced from the predicted coefficients. Convenience method for the + VASP comparison pipeline. + +* Metric integration: the predicted density grid feeds into + ``compute_nmape`` / ``compute_rmse`` / ``compute_nrmse`` from + ``charge3net_ft.train`` and they return finite scalars. Pinned per the + brief: "Keep the metric calculations identical to our ChargE3Net pipeline." +""" + +from __future__ import annotations + +import ase +import numpy as np +import torch + + +def _cubic_atoms(symbols=("Fe",), fractional=((0.0, 0.0, 0.0),), a=4.0): + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +class TestSALTEDModelConstruct: + def test_constructs_with_basis_spec(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + assert m.basis_spec is spec + + def test_default_ckpt_is_none(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + assert m.ckpt_path is None + + +class TestSALTEDModelOutputShape: + def test_single_atom_output_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + coeffs = m(_cubic_atoms()) + assert coeffs.shape == (1, spec.n_coeffs_per_atom) + + def test_multi_atom_output_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + spec = BasisSpec() + m = SALTEDModel(basis_spec=spec) + atoms = _cubic_atoms( + symbols=("Fe", "O", "Fe"), + fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5), (0.25, 
0.25, 0.25)), + ) + coeffs = m(atoms) + assert coeffs.shape == (3, spec.n_coeffs_per_atom) + + def test_output_dtype_is_float64(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + coeffs = m(_cubic_atoms()) + assert coeffs.dtype == np.float64 + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + coeffs = m(_cubic_atoms()) + assert np.isfinite(coeffs).all() + + +class TestSALTEDModelDeterminism: + def test_same_input_gives_same_output(self): + """Reproducibility: critical for CI + regression tests.""" + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + atoms = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6)) + ) + c1 = m(atoms) + c2 = m(atoms) + np.testing.assert_array_equal(c1, c2) + + def test_different_positions_give_different_coefficients(self): + """A degenerate stub that always returned zeros would pass shape + + determinism but be useless. Require some position-dependent + variation so downstream tests have signal to work with. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + atoms_a = _cubic_atoms(fractional=((0.0, 0.0, 0.0),)) + atoms_b = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + c_a = m(atoms_a) + c_b = m(atoms_b) + assert not np.allclose(c_a, c_b), ( + "predicted coefficients must depend on atom positions; the stub " + "appears to return position-independent constants" + ) + + def test_baseline_ckpt_loads_and_predicts(self, tmp_path): + """Real-mode path: save a D6 baseline ckpt, instantiate SALTEDModel + with its path, and verify forward returns the expected shape and + the prediction differs from stub-mode output (so we know the + ckpt actually drove the result).""" + import pytest + + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + from salted_ft.train_baseline import SaltedBaselineModel + + spec = BasisSpec() + torch.manual_seed(42) + baseline = SaltedBaselineModel(spec) + ckpt = tmp_path / "salted_baseline.pt" + torch.save({"basis_spec": spec, "model": baseline.state_dict()}, ckpt) + + atoms = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.1, 0.2, 0.3), (0.4, 0.5, 0.6)) + ) + + m_stub = SALTEDModel(spec) + m_loaded = SALTEDModel(spec, ckpt_path=ckpt) + + out_stub = m_stub(atoms) + out_loaded = m_loaded(atoms) + + assert out_loaded.shape == (2, spec.n_coeffs_per_atom) + assert not np.allclose(out_loaded, out_stub), ( + "loaded ckpt produced the same output as the stub seed; " + "the ckpt path likely is not being exercised" + ) + + def test_bad_ckpt_format_raises_clearly(self, tmp_path): + import pytest + + pytest.importorskip("torch") + import torch + + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + ckpt = tmp_path / "bad.pt" + torch.save({"not_a_baseline": "anything"}, ckpt) + m = SALTEDModel(BasisSpec(), ckpt_path=ckpt) + atoms = _cubic_atoms() + with 
pytest.raises(RuntimeError, match="baseline format"): + m(atoms) + + def test_perturbing_non_first_atom_changes_coefficients(self): + """Regression test for the int.from_bytes(seed_bytes[:16], ...) + bug: with the old seeding, only atom 0's xyz (the first 24 + bytes) contributed to the seed, so perturbing atom 1+ produced + identical coefficients. The blake2b hash fixes this. + """ + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + atoms_a = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + atoms_b = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.0, 0.0, 0.0), (0.6, 0.5, 0.5)) + ) + c_a = m(atoms_a) + c_b = m(atoms_b) + assert not np.array_equal(c_a, c_b), ( + "perturbing atom 1 must change the coefficient output; " + "if not, the stub seed only uses atom 0's bytes" + ) + + +class TestSALTEDModelReconstructDensity: + def test_reconstruct_density_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + grid = m.reconstruct_density(_cubic_atoms(), (8, 8, 8)) + assert grid.shape == (8, 8, 8) + + def test_reconstruct_density_matches_explicit_path(self): + """``model.reconstruct_density(atoms, shape)`` must equal calling + ``model(atoms)`` then ``reconstruct_grid_from_basis(c, ...)``. + Convenience method is just sugar. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms( + symbols=("Fe", "O"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + m = SALTEDModel(basis_spec=spec) + c = m(atoms) + expected = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + got = m.reconstruct_density(atoms, (8, 8, 8)) + np.testing.assert_array_equal(got, expected) + + def test_reconstruct_density_dtype_and_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + m = SALTEDModel(basis_spec=BasisSpec()) + grid = m.reconstruct_density(_cubic_atoms(), (8, 8, 8)) + assert grid.dtype == np.float64 + assert np.isfinite(grid).all() + + +class TestMetricIntegration: + """Predicted density grid feeds the existing ChargE3Net metric functions.""" + + def _to_torch_batch(self, grid: np.ndarray) -> torch.Tensor: + """Flatten a (Nx, Ny, Nz) grid into a (B=1, N_probes) torch tensor. + + ChargE3Net's compute_nmape signature is (preds, targets, num_probes). + For full-grid evaluation we use B=1 and num_probes=None. 
+ """ + return torch.from_numpy(grid.astype(np.float32).reshape(1, -1)) + + def test_compute_nmape_returns_finite_scalar(self): + from charge3net_ft.train import compute_nmape + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + # Synthetic target: same shape, non-zero so the NMAPE denominator is positive + targets = torch.ones_like(preds) + nmape = compute_nmape(preds, targets, num_probes=None) + assert nmape.numel() == 1 + assert torch.isfinite(nmape).all() + + def test_compute_rmse_returns_finite_scalar(self): + from charge3net_ft.train import compute_rmse + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + targets = torch.ones_like(preds) + rmse = compute_rmse(preds, targets, num_probes=None) + assert torch.isfinite(rmse).all() + + def test_compute_nrmse_returns_finite_scalar(self): + from charge3net_ft.train import compute_nrmse + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + targets = torch.ones_like(preds) + nrmse = compute_nrmse(preds, targets, num_probes=None) + assert torch.isfinite(nrmse).all() + + def test_perfect_prediction_gives_zero_nmape(self): + """Sanity check: NMAPE of a tensor against itself is zero.""" + from charge3net_ft.train import compute_nmape + from salted_ft.basis import BasisSpec + from salted_ft.model import SALTEDModel + + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),)) + m = SALTEDModel(basis_spec=BasisSpec()) + preds = 
self._to_torch_batch(m.reconstruct_density(atoms, (8, 8, 8))) + # Self-similarity: target identical to prediction => zero error. + nmape = compute_nmape(preds, preds.clone(), num_probes=None) + assert nmape.item() == 0.0 diff --git a/tests/test_salted_project_dataset.py b/tests/test_salted_project_dataset.py new file mode 100644 index 0000000..f9c8253 --- /dev/null +++ b/tests/test_salted_project_dataset.py @@ -0,0 +1,236 @@ +"""TDD tests for the Phase D2 dataset-projection module. + +Locks the contract for ``salted_ft.project_dataset.project_chunk``, +which reads a LeMat-Rho-format parquet chunk, runs +``project_chgcar_to_basis`` row by row, and writes a parallel parquet +chunk of projected coefficients. + +Output schema per row:: + + { + "row_index": int (matches the original chunk row index), + "material_id": str (carried through if present, else "" ), + "n_atoms": int, + "atomic_numbers": list[int], + "lattice_vectors": list[list[float]], # 3x3 + "n_electrons": float (integrated density * cell_volume / n_grid), + "grid_shape": list[int], # [Nx, Ny, Nz] + "coefficients": list[list[float]], # (n_atoms, n_coeffs_per_atom) + "basis_set_NMAPE": float (basis-ceiling NMAPE for this row), + } + +The basis_set_NMAPE column is the per-row reconstruction error from +roundtripping; we keep it so downstream sanity-checks can know each +sample's basis ceiling. 
+""" + +from __future__ import annotations + +import json +import tempfile +from pathlib import Path + +import numpy as np +import pyarrow as pa +import pyarrow.parquet as pq + + +def _write_synthetic_chunk(path: Path, n_rows: int = 3) -> None: + """Write a LeMat-Rho-format chunk for use by the projection script.""" + rng = np.random.default_rng(42) + grids = [ + json.dumps(rng.random((10, 10, 10), dtype=np.float64).tolist()) + for _ in range(n_rows) + ] + table = pa.table( + { + "compressed_charge_density": pa.array(grids, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * n_rows), + "cartesian_site_positions": pa.array([[[2.0, 2.0, 2.0]]] * n_rows), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * n_rows + ), + # extras: confirm they get ignored + "bader_charges": pa.array([[0.4]] * n_rows), + "material_id": pa.array([f"mat_{i:03d}" for i in range(n_rows)]), + } + ) + pq.write_table(table, path) + + +class TestProjectChunkContract: + """``project_chunk(in_path, out_path, basis_spec)`` -> None. + + Reads ``in_path`` (LeMat-Rho format), projects each row, writes + ``out_path`` in the schema documented at the top of this file. 
+ """ + + def test_output_file_written(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + assert out.exists() + assert out.stat().st_size > 0 + + def test_row_count_matches_valid_input(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out) + assert len(t) == 3 + + def test_required_columns_present(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out) + required = { + "row_index", + "material_id", + "n_atoms", + "atomic_numbers", + "lattice_vectors", + "n_electrons", + "grid_shape", + "coefficients", + "basis_set_NMAPE", + } + missing = required - set(t.column_names) + assert not missing, f"missing required columns: {missing}" + + def test_coefficient_shape_per_row(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + spec = BasisSpec() + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=2) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, spec) + t = pq.read_table(out).to_pydict() + for c, n_atoms in zip(t["coefficients"], t["n_atoms"], strict=True): + # Each row has its own coefficient block; first dim is n_atoms, + # second is n_coeffs_per_atom. 
+ arr = np.asarray(c) + assert arr.shape == (n_atoms, spec.n_coeffs_per_atom), ( + f"row coefficient shape mismatch: got {arr.shape}, " + f"expected ({n_atoms}, {spec.n_coeffs_per_atom})" + ) + + def test_basis_set_NMAPE_is_finite_and_nonnegative(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + for x in t["basis_set_NMAPE"]: + assert np.isfinite(x) + assert x >= 0.0 + + def test_material_id_preserved(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + _write_synthetic_chunk(d / "in.parquet", n_rows=3) + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + assert t["material_id"] == ["mat_000", "mat_001", "mat_002"] + + def test_handles_null_charge_density_rows(self): + """Real LeMat-Rho chunks have some rows with NULL density (failed + DFT extraction). Those should be skipped, not crash the projection. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_chunk + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + grids = [ + json.dumps(np.ones((10, 10, 10)).tolist()), + None, # null density - should be skipped + json.dumps(np.ones((10, 10, 10)).tolist()), + ] + table = pa.table( + { + "compressed_charge_density": pa.array(grids, type=pa.string()), + "species_at_sites": pa.array([["Fe"]] * 3), + "cartesian_site_positions": pa.array([[[2.0, 2.0, 2.0]]] * 3), + "lattice_vectors": pa.array( + [[[4.0, 0.0, 0.0], [0.0, 4.0, 0.0], [0.0, 0.0, 4.0]]] * 3 + ), + "material_id": pa.array(["a", "b", "c"]), + } + ) + pq.write_table(table, d / "in.parquet") + out = d / "out.parquet" + project_chunk(d / "in.parquet", out, BasisSpec()) + t = pq.read_table(out).to_pydict() + assert len(t["row_index"]) == 2 + assert t["row_index"] == [0, 2] + + +class TestProjectDirectory: + """Driver that runs project_chunk over every chunk_*.parquet in a dir.""" + + def test_processes_all_chunks(self): + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_directory + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + in_d = d / "in" + in_d.mkdir() + out_d = d / "out" + for i in range(3): + _write_synthetic_chunk(in_d / f"chunk_{i:06d}.parquet", n_rows=2) + project_directory(in_d, out_d, BasisSpec()) + outputs = sorted(out_d.glob("chunk_*.parquet")) + assert len(outputs) == 3 + for out in outputs: + assert pq.read_table(out).num_rows == 2 + + def test_skips_existing_outputs(self): + """Idempotent: a re-run does not re-project chunks that already exist. + + Lets us resume a partially-completed projection job after an + interruption without paying the LSQR cost again. 
+ """ + from salted_ft.basis import BasisSpec + from salted_ft.project_dataset import project_directory + + with tempfile.TemporaryDirectory() as tmp: + d = Path(tmp) + in_d = d / "in" + in_d.mkdir() + out_d = d / "out" + _write_synthetic_chunk(in_d / "chunk_000000.parquet", n_rows=2) + # First run + project_directory(in_d, out_d, BasisSpec()) + first_mtime = (out_d / "chunk_000000.parquet").stat().st_mtime + # Second run should be a no-op + project_directory(in_d, out_d, BasisSpec()) + second_mtime = (out_d / "chunk_000000.parquet").stat().st_mtime + assert first_mtime == second_mtime diff --git a/tests/test_salted_projection.py b/tests/test_salted_projection.py new file mode 100644 index 0000000..fc79801 --- /dev/null +++ b/tests/test_salted_projection.py @@ -0,0 +1,225 @@ +"""TDD tests for VASP CHGCAR <-> SALTED basis projection / reconstruction. + +These two operations are the DIY bridge layer between VASP plane-wave +densities and the rholearn/SALTED localized-basis world (see the +``plan_salted_graph2mat_basis_choice_may_20_pm.md`` memo for context). + +Locked contracts here: + +* ``project_chgcar_to_basis(density, atoms, basis_spec)`` + -> ``np.ndarray (n_atoms, n_coeffs_per_atom)`` float64. + Zero density gives zero coefficients. Linear in the input density. + +* ``reconstruct_grid_from_basis(coefficients, atoms, grid_shape, basis_spec)`` + -> ``np.ndarray (Nx, Ny, Nz)`` float64. + Zero coefficients give a zero grid. Linear in the coefficients. + A single-atom, l=0, n=0 unit coefficient produces a Gaussian peaked + at the atom position. + +The roundtrip is intentionally NOT pinned to high accuracy in this PR. +A simple orthonormal-approximation projection is enough to land the +contract; a future PR will swap in least-squares solving against the +full basis overlap matrix for tight roundtrip accuracy. 
+""" + +from __future__ import annotations + +import ase +import numpy as np + + +# --------------------------------------------------------------------------- +# Helpers — small synthetic structures so tests stay fast and inspectable. +# --------------------------------------------------------------------------- +def _cubic_atoms(symbols=("Fe",), fractional=((0.0, 0.0, 0.0),), a=4.0): + """Single-cell ase.Atoms with the requested species/positions in fractional coords.""" + cell = np.eye(3) * a + cart = np.array(fractional) @ cell + return ase.Atoms(symbols=list(symbols), positions=cart, cell=cell, pbc=True) + + +def _zero_grid(shape=(8, 8, 8)) -> np.ndarray: + return np.zeros(shape, dtype=np.float32) + + +def _random_grid(shape=(8, 8, 8), seed: int = 0) -> np.ndarray: + rng = np.random.default_rng(seed) + return rng.random(shape, dtype=np.float32) + + +# --------------------------------------------------------------------------- +# Projection: density grid -> coefficients +# --------------------------------------------------------------------------- +class TestProjectChgcarToBasis: + def test_output_shape_is_n_atoms_by_n_coeffs(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + spec = BasisSpec() + atoms = _cubic_atoms( + symbols=("Fe", "Fe"), fractional=((0.0, 0.0, 0.0), (0.5, 0.5, 0.5)) + ) + coeffs = project_chgcar_to_basis(_zero_grid(), atoms, spec) + assert coeffs.shape == (2, spec.n_coeffs_per_atom) + + def test_zero_density_gives_zero_coefficients(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + coeffs = project_chgcar_to_basis(_zero_grid(), _cubic_atoms(), BasisSpec()) + np.testing.assert_array_equal(coeffs, 0.0) + + def test_output_dtype_is_float64(self): + """float64 because we'll feed these to scipy/least-squares downstream.""" + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + 
coeffs = project_chgcar_to_basis(_random_grid(), _cubic_atoms(), BasisSpec()) + assert coeffs.dtype == np.float64 + + def test_linearity_in_density(self): + """project(alpha * rho) == alpha * project(rho); a basic sanity check + since both projection and reconstruction must be linear maps. + """ + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + atoms = _cubic_atoms() + spec = BasisSpec() + rho = _random_grid(seed=1) + c1 = project_chgcar_to_basis(rho, atoms, spec) + c_scaled = project_chgcar_to_basis(2.5 * rho, atoms, spec) + np.testing.assert_allclose(c_scaled, 2.5 * c1, rtol=1e-5, atol=1e-8) + + def test_additivity_in_density(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + atoms = _cubic_atoms() + spec = BasisSpec() + rho1 = _random_grid(seed=2) + rho2 = _random_grid(seed=3) + c1 = project_chgcar_to_basis(rho1, atoms, spec) + c2 = project_chgcar_to_basis(rho2, atoms, spec) + c_sum = project_chgcar_to_basis(rho1 + rho2, atoms, spec) + np.testing.assert_allclose(c_sum, c1 + c2, rtol=1e-5, atol=1e-8) + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import project_chgcar_to_basis + + coeffs = project_chgcar_to_basis(_random_grid(), _cubic_atoms(), BasisSpec()) + assert np.isfinite(coeffs).all() + + +# --------------------------------------------------------------------------- +# Reconstruction: coefficients -> density grid +# --------------------------------------------------------------------------- +class TestReconstructGridFromBasis: + def test_output_shape_matches_grid_shape(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert grid.shape == (8, 8, 8) + + def 
test_zero_coefficients_give_zero_grid(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + np.testing.assert_array_equal(grid, 0.0) + + def test_output_dtype_is_float64(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(4) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert grid.dtype == np.float64 + + def test_linearity_in_coefficients(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(5) + c = rng.standard_normal((1, spec.n_coeffs_per_atom)) + g1 = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + g_scaled = reconstruct_grid_from_basis(3.0 * c, atoms, (8, 8, 8), spec) + np.testing.assert_allclose(g_scaled, 3.0 * g1, rtol=1e-5, atol=1e-8) + + def test_single_atom_l0_n0_peaks_at_atom_position(self): + """Unit s-coefficient on the first radial channel: density should peak + at the atom position (not somewhere else in the cell).""" + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + # Atom at the (0.5, 0.5, 0.5) interior point, away from cell edges. + atoms = _cubic_atoms(fractional=((0.5, 0.5, 0.5),), a=4.0) + coeffs = np.zeros((1, spec.n_coeffs_per_atom)) + coeffs[0, 0] = 1.0 # l=0, m=0, n=0 (the most localized s channel) + grid = reconstruct_grid_from_basis(coeffs, atoms, (16, 16, 16), spec) + + # Peak index in (i, j, k) integer grid should be near the center. 
+ peak_idx = np.unravel_index(np.argmax(grid), grid.shape) + center = (8, 8, 8) # fractional 0.5 on a 16-point grid + for actual, expected in zip(peak_idx, center, strict=True): + assert abs(actual - expected) <= 1, ( + f"density peak {peak_idx} is far from atom (expected near {center}); " + "either the atom-position lookup or the basis evaluation is wrong" + ) + + def test_output_is_finite(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import reconstruct_grid_from_basis + + spec = BasisSpec() + atoms = _cubic_atoms() + rng = np.random.default_rng(6) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + assert np.isfinite(grid).all() + + +# --------------------------------------------------------------------------- +# Roundtrip: project then reconstruct (and vice versa). +# --------------------------------------------------------------------------- +class TestProjectionReconstructionRoundtrip: + def test_roundtrip_of_zero_density_is_zero(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import ( + project_chgcar_to_basis, + reconstruct_grid_from_basis, + ) + + atoms = _cubic_atoms() + spec = BasisSpec() + coeffs = project_chgcar_to_basis(_zero_grid(), atoms, spec) + roundtrip = reconstruct_grid_from_basis(coeffs, atoms, (8, 8, 8), spec) + np.testing.assert_array_equal(roundtrip, 0.0) + + def test_roundtrip_of_zero_coefficients_is_zero(self): + from salted_ft.basis import BasisSpec + from salted_ft.projection import ( + project_chgcar_to_basis, + reconstruct_grid_from_basis, + ) + + atoms = _cubic_atoms() + spec = BasisSpec() + c = np.zeros((1, spec.n_coeffs_per_atom)) + grid = reconstruct_grid_from_basis(c, atoms, (8, 8, 8), spec) + c_back = project_chgcar_to_basis(grid, atoms, spec) + np.testing.assert_array_equal(c_back, 0.0) diff --git a/tests/test_salted_rholearn_adapter.py b/tests/test_salted_rholearn_adapter.py new file mode 100644 
index 0000000..4e77775 --- /dev/null +++ b/tests/test_salted_rholearn_adapter.py @@ -0,0 +1,235 @@ +"""TDD tests for the SALTED -> rholearn data adapter (Phase D3). + +rholearn's training loop consumes basis-coefficient vectors in a +specific flat layout (see ``rholearn/utils/convert.py::_get_flat_index``): + + atom (outer) -> o3_lambda -> n (radial, INNER to lambda) -> o3_mu (innermost) + +Our ``salted_ft.projection`` layout differs: + + atom (outer) -> n (radial, OUTER to lambda) -> (lambda, mu) packed + +The adapter functions tested here move between the two layouts and +produce the ``lmax`` / ``nmax`` dicts rholearn's metatensor converter +needs to know the basis shape. +""" + +from __future__ import annotations + +import numpy as np +import pytest + + +# --------------------------------------------------------------------------- +# rholearn's lmax / nmax dict format (from rholearn/utils/convert.py docstrings) +# +# lmax = {"H": 1, "C": 2} per-species max lambda +# nmax = {("H", 0): 2, ("H", 1): 3, ("C", 0): 4, ...} per-species per-lambda n +# +# Our uniform BasisSpec has max_l + n_radial constant across species. The +# adapter expands that into rholearn's per-species dicts so the same basis +# spec works for arbitrary species sets. 
+# --------------------------------------------------------------------------- + + +class TestBuildLmaxNmaxDicts: + """Convert our uniform BasisSpec into rholearn's per-species dicts.""" + + def test_lmax_contains_every_species(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + lmax, nmax = build_lmax_nmax(BasisSpec(), species=("H", "O", "Fe")) + assert set(lmax) == {"H", "O", "Fe"} + + def test_lmax_value_matches_basis_spec(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + lmax, _ = build_lmax_nmax(spec, species=("Fe",)) + assert lmax["Fe"] == spec.max_l + + def test_nmax_keyed_by_species_and_lambda(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + _, nmax = build_lmax_nmax(spec, species=("H", "Fe")) + # Both species share the same n_radial at every lambda + for s in ("H", "Fe"): + for lam in range(spec.max_l + 1): + assert nmax[(s, lam)] == spec.n_radial, ( + f"nmax[({s!r}, {lam})] must be {spec.n_radial}, " + f"got {nmax[(s, lam)]}" + ) + + def test_total_per_atom_coefficients_match(self): + """Sum of ``(2*l + 1) * nmax[(s, l)]`` across l must equal + ``BasisSpec.n_coeffs_per_atom``. If this drifts the flat vector + produced by the adapter will be the wrong length. + """ + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import build_lmax_nmax + + spec = BasisSpec() + lmax, nmax = build_lmax_nmax(spec, species=("Fe",)) + total = sum((2 * lam + 1) * nmax[("Fe", lam)] for lam in range(lmax["Fe"] + 1)) + assert total == spec.n_coeffs_per_atom + + +# --------------------------------------------------------------------------- +# Reordering: our (atom, n_outer, lm_packed) <-> rholearn (atom, l, n, mu). +# Pure ndarray math, no metatensor required. 
+# --------------------------------------------------------------------------- + + +class TestDenseToRholearnFlat: + """``dense_to_rholearn_flat(coeffs, basis_spec, symbols) -> np.ndarray``.""" + + def test_output_length_matches_total_basis(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + atoms = ("Fe", "Fe") + coeffs = np.zeros((2, spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, atoms) + assert flat.shape == (2 * spec.n_coeffs_per_atom,) + + def test_zero_in_gives_zero_out(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + flat = dense_to_rholearn_flat( + np.zeros((1, spec.n_coeffs_per_atom)), spec, ("Fe",) + ) + np.testing.assert_array_equal(flat, 0.0) + + def test_dtype_preserved(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + rng = np.random.default_rng(0) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)).astype(np.float64) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + assert flat.dtype == np.float64 + + def test_concatenates_atoms_in_order(self): + """Per-atom blocks must appear in input order (atom 0 first, then 1, ...).""" + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + # Use distinguishable per-atom values + coeffs = np.zeros((2, spec.n_coeffs_per_atom)) + coeffs[0, :] = 1.0 + coeffs[1, :] = 2.0 + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe", "Fe")) + per_atom = spec.n_coeffs_per_atom + assert np.allclose(flat[:per_atom], 1.0) + assert np.allclose(flat[per_atom:], 2.0) + + +class TestRoundtrip: + """dense -> rholearn-flat -> dense must be exactly the identity.""" + + def test_roundtrip_random_single_atom(self): + from salted_ft.basis import BasisSpec + from 
salted_ft.rholearn_adapter import ( + dense_to_rholearn_flat, + rholearn_flat_to_dense, + ) + + spec = BasisSpec() + rng = np.random.default_rng(1) + coeffs = rng.standard_normal((1, spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + restored = rholearn_flat_to_dense(flat, spec, ("Fe",)) + np.testing.assert_array_equal(restored, coeffs) + + def test_roundtrip_random_multi_atom(self): + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import ( + dense_to_rholearn_flat, + rholearn_flat_to_dense, + ) + + spec = BasisSpec() + rng = np.random.default_rng(2) + symbols = ("Fe", "O", "Fe", "H") + coeffs = rng.standard_normal((len(symbols), spec.n_coeffs_per_atom)) + flat = dense_to_rholearn_flat(coeffs, spec, symbols) + restored = rholearn_flat_to_dense(flat, spec, symbols) + np.testing.assert_array_equal(restored, coeffs) + + def test_permutation_is_actually_nontrivial(self): + """The reordering must MOVE values around -- if dense -> flat were + the identity that would mean we'd silently feed misordered data to + rholearn. Pinning this catches a future 'simplification' that + accidentally drops the permutation. + """ + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_rholearn_flat + + spec = BasisSpec() + # Distinguishable per-channel values via arange + coeffs = np.arange(spec.n_coeffs_per_atom, dtype=np.float64).reshape( + 1, spec.n_coeffs_per_atom + ) + flat = dense_to_rholearn_flat(coeffs, spec, ("Fe",)) + # rholearn's ordering is atom -> lambda -> n -> mu; ours is + # atom -> n -> lambda -> mu. So flat[0] is c[atom=0, lambda=0, n=0, mu=0] + # which in OUR layout is at position [n=0, lm=0] = 0. So flat[0] == 0. + # After that the layouts diverge: rholearn walks every n of lambda=0 + # before reaching lambda=1, while ours walks every (lambda, mu) of n=0 + # before reaching n=1. With n_radial > 1 and max_l >= 1, slot 3 holds + # a different channel in each layout (the exact source index depends + # on the BasisSpec defaults), so flat[3] cannot equal 3. 
+ # So flat[3] != coeffs[0, 3] is the load-bearing check. + assert flat[3] != coeffs[0, 3], ( + "ordering is trivial; the reordering should move values around" + ) + + +# --------------------------------------------------------------------------- +# Smoke test for the full TensorMap path. Heavier dependency on metatensor +# but the test is short. +# --------------------------------------------------------------------------- + + +class TestDenseToTensorMap: + """``dense_to_tensormap`` returns a metatensor TensorMap with the right keys. + + Requires the rholearn sibling repo at ``../rholearn/`` (auto-skips + when absent). On Adastra (where rholearn IS installed) this test + activates and exercises the full conversion path. + """ + + def test_tensormap_has_o3_lambda_center_type_keys(self): + pytest.importorskip("metatensor") + pytest.importorskip("chemfiles") + + from pathlib import Path + + if not (Path(__file__).resolve().parent.parent.parent / "rholearn").exists(): + pytest.skip("rholearn sibling repo not present; skipping live conversion") + + from salted_ft.basis import BasisSpec + from salted_ft.rholearn_adapter import dense_to_tensormap + + spec = BasisSpec() + positions = np.array([[0.0, 0.0, 0.0], [2.0, 2.0, 2.0]]) + cell = np.eye(3) * 4.0 + symbols = ("Fe", "Fe") + rng = np.random.default_rng(3) + coeffs = rng.standard_normal((2, spec.n_coeffs_per_atom)) + tmap = dense_to_tensormap( + coeffs, spec, symbols, positions, cell, structure_idx=0 + ) + # Keys must contain ``o3_lambda`` and ``center_type`` per rholearn's + # convention (see rholearn/utils/convert.py docstrings). + names = list(tmap.keys.names) + assert "o3_lambda" in names + assert "center_type" in names diff --git a/tests/test_scf_speedup_run.py b/tests/test_scf_speedup_run.py new file mode 100644 index 0000000..1ac3e1d --- /dev/null +++ b/tests/test_scf_speedup_run.py @@ -0,0 +1,596 @@ +"""TDD tests for ``scripts/scf_speedup_run.py`` (P4). 
+ +The driver loops a held-out test parquet, predicts each row's +density via the chosen ML arm, writes a CHGCAR with the right +electron-count rescaling, and submits a paired baseline + predicted +VASP Flow via ``entalsim.dft.scf_speedup.make_scf_speedup_pair`` + +``entalsim.core.submit.submit_workflow``. + +Tests use dependency injection (``make_pair_fn`` and ``submit_fn``) +so they pass locally without entalsim installed. The real CLI +imports entalsim's functions at runtime. +""" + +from __future__ import annotations + +import importlib +import sys +from pathlib import Path +from types import SimpleNamespace + +import numpy as np +import pandas as pd +import pytest + + +@pytest.fixture +def run_module(): + scripts_dir = Path(__file__).resolve().parent.parent / "scripts" + if str(scripts_dir) not in sys.path: + sys.path.insert(0, str(scripts_dir)) + if "scf_speedup_run" in sys.modules: + del sys.modules["scf_speedup_run"] + return importlib.import_module("scf_speedup_run") + + +def _toy_parquet(tmp_path: Path, n_rows: int = 2) -> Path: + """Synthesise a held-out-split-shaped parquet. + + Columns mirror what the held-out split builder will emit: + material_id, atomic_numbers, positions (flat), lattice_vectors + (flat 9), grid_shape, n_electrons. 
+ """ + rows = [] + grid_shape = (4, 4, 4) + for i in range(n_rows): + n_atoms = 2 + rows.append( + { + "material_id": f"mp-toy-{i}", + "n_atoms": n_atoms, + "atomic_numbers": np.array([1, 1], dtype=np.int64), + "positions": np.array( + [[0.0, 0.0, 0.0], [0.74 + 0.01 * i, 0.0, 0.0]], + dtype=np.float64, + ).reshape(-1), + "lattice_vectors": (np.eye(3) * 5.0).reshape(-1), + "grid_shape": np.array(grid_shape, dtype=np.int64), + "n_electrons": 2.0, + } + ) + out = tmp_path / "held_out.parquet" + pd.DataFrame(rows).to_parquet(out) + return out + + +def _fake_flow(n_jobs: int = 2): + return SimpleNamespace( + jobs=[SimpleNamespace(uuid=f"j{i}") for i in range(n_jobs)], + name="fake_flow", + ) + + +def _make_pair_mock(captured: list): + """Returns a (mock, captured) pair. ``captured`` records each call.""" + + def make_pair(structure, predicted_chgcar_dir, metadata): + captured.append( + { + "structure_formula": structure.composition.reduced_formula, + "predicted_chgcar_dir": str(predicted_chgcar_dir), + "metadata": dict(metadata), + "chgcar_exists": (Path(predicted_chgcar_dir) / "CHGCAR").exists(), + } + ) + return _fake_flow() + + return make_pair + + +def _submit_mock(captured: list): + def submit(flow, project, worker): + captured.append( + {"project": project, "worker": worker, "n_jobs": len(flow.jobs)} + ) + + return submit + + +class TestDriverBasics: + def test_dry_run_writes_one_chgcar_per_row(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=2) + chgcar_dir = tmp_path / "chgcars" + make_calls: list = [] + submit_calls: list = [] + + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="test_project", + worker="test_worker", + dry_run=True, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock(submit_calls), + ) + assert len(records) == 2 + for r in records: + assert 
Path(r["chgcar_path"]).exists() + assert submit_calls == [], "dry_run=True must not submit" + + def test_make_pair_invoked_with_metadata(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=2) + chgcar_dir = tmp_path / "chgcars" + make_calls: list = [] + + run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="test_project", + worker="test_worker", + dry_run=True, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + assert len(make_calls) == 2 + for call in make_calls: + md = call["metadata"] + assert md["experiment"] == "scf_speedup" + assert md["model"] == "salted" + assert md["material_id"].startswith("mp-toy-") + assert call["chgcar_exists"], ( + "make_scf_speedup_pair must see a real CHGCAR file at the path " + "we hand it; otherwise its FileNotFoundError fires on every row" + ) + + def test_limit_caps_rows_processed(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=5) + chgcar_dir = tmp_path / "chgcars" + make_calls: list = [] + + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + limit=2, + dry_run=True, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + assert len(records) == 2 + assert len(make_calls) == 2 + + +class TestSubmitWiring: + def test_non_dry_run_calls_submit_per_row(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=2) + chgcar_dir = tmp_path / "chgcars" + submit_calls: list = [] + + run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="jz_scf_speedup", + worker="jean_zay_cpu", + dry_run=False, + 
make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock(submit_calls), + ) + assert len(submit_calls) == 2 + for call in submit_calls: + assert call["project"] == "jz_scf_speedup" + assert call["worker"] == "jean_zay_cpu" + assert call["n_jobs"] == 2 + + def test_records_include_submitted_flag(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + chgcar_dir = tmp_path / "chgcars" + + dry = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + wet = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "chgcars_wet", + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=False, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + assert dry[0]["submitted"] is False + assert wet[0]["submitted"] is True + + +class TestArmCheckpointGuard: + def test_charge3net_without_ckpt_fails_fast(self, tmp_path, run_module): + """ChargE3Net and DeepDFT without a checkpoint run as random-init + models. Their predictions would be meaningless, and we would + silently waste HPC time. The driver must refuse before any + prediction or submit. 
+ """ + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + with pytest.raises(ValueError, match="ckpt"): + run_module.run_experiment( + model_name="charge3net", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "c", + basis_spec=BasisSpec(), + project="p", + worker="w", + ckpt=None, + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + + def test_deepdft_without_ckpt_fails_fast(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + with pytest.raises(ValueError, match="ckpt"): + run_module.run_experiment( + model_name="deepdft", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "c", + basis_spec=BasisSpec(), + project="p", + worker="w", + ckpt=None, + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + + def test_salted_without_ckpt_uses_stub(self, tmp_path, run_module): + """SALTED stub mode is the documented fallback. The driver must + let it through so we can dry-run the pipeline before D6 trained + weights are available.""" + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "c", + basis_spec=BasisSpec(), + project="p", + worker="w", + ckpt=None, + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + assert records[0]["ckpt"] == "stub" + + +class TestPerRowResilience: + """A multi-hour batch must not die on a single bad row.""" + + def test_per_row_failure_does_not_abort_loop(self, tmp_path, run_module): + """If row 2 of 3 has a corrupt cell (positions with wrong + length) the loop must skip it, record the failure, and keep + going. Otherwise the prior rows' Flows are submitted to + Mongo with no clean resume path.""" + from salted_ft.basis import BasisSpec + + # 3 rows, middle one has corrupt positions. 
+ rows = [] + grid_shape = (4, 4, 4) + good_positions = np.array( + [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]], dtype=np.float64 + ).reshape(-1) + for i in range(3): + pos = good_positions + if i == 1: + # Length 2: positions.reshape(-1, 3) raises. + pos = np.array([0.0, 0.0], dtype=np.float64) + rows.append( + { + "material_id": f"mp-toy-{i}", + "n_atoms": 2, + "atomic_numbers": np.array([1, 1], dtype=np.int64), + "positions": pos, + "lattice_vectors": (np.eye(3) * 5.0).reshape(-1), + "grid_shape": np.array(grid_shape, dtype=np.int64), + "n_electrons": 2.0, + } + ) + in_parquet = tmp_path / "held_out_with_bad_row.parquet" + pd.DataFrame(rows).to_parquet(in_parquet) + + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "chgcars", + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + assert len(records) == 3 + good = [r for r in records if r.get("error") is None] + bad = [r for r in records if r.get("error") is not None] + assert len(good) == 2 + assert len(bad) == 1 + assert bad[0]["material_id"] == "mp-toy-1" + assert bad[0]["submitted"] is False + assert "reshape" in bad[0]["error"] or "cannot" in bad[0]["error"] + + +class TestManifest: + def test_manifest_jsonl_written_after_each_row(self, tmp_path, run_module): + """The manifest must be written incrementally so an + interrupted run leaves a resumable record. 
After all rows + complete the manifest should have one JSONL line per row.""" + import json + + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=3) + chgcar_dir = tmp_path / "chgcars" + manifest = tmp_path / "manifest.jsonl" + + run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + manifest_path=manifest, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + assert manifest.exists() + lines = manifest.read_text().splitlines() + assert len(lines) == 3 + for line in lines: + rec = json.loads(line) + assert "material_id" in rec + assert "model" in rec + + def test_manifest_defaults_to_chgcar_dir(self, tmp_path, run_module): + """If --manifest is not given, default to + chgcar_dir/manifest.jsonl so a re-run can find it by + convention.""" + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + chgcar_dir = tmp_path / "chgcars" + + run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + assert (chgcar_dir / "manifest.jsonl").exists() + + +class TestSkipExisting: + def test_skip_existing_skips_already_submitted_rows(self, tmp_path, run_module): + """Pre-populate a manifest with one submitted row, then + re-run with skip_existing=True; only the unseen rows should + be processed.""" + import json + + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=3) + chgcar_dir = tmp_path / "chgcars" + chgcar_dir.mkdir() + manifest = chgcar_dir / "manifest.jsonl" + # Mark mp-toy-1 as already done. 
+ manifest.write_text( + json.dumps( + { + "material_id": "mp-toy-1", + "model": "salted", + "submitted": True, + "error": None, + } + ) + + "\n" + ) + make_calls: list = [] + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + skip_existing=True, + manifest_path=manifest, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + # mp-toy-1 should NOT have been re-processed. + processed_ids = {call["metadata"]["material_id"] for call in make_calls} + assert "mp-toy-1" not in processed_ids + assert processed_ids == {"mp-toy-0", "mp-toy-2"} + # Records reflect what THIS run did, not the historical entry. + assert len(records) == 2 + + def test_skip_existing_does_not_skip_failed_rows(self, tmp_path, run_module): + """A row in the manifest with submitted=False (error from a + previous run) should be retried on the next run, not skipped.""" + import json + + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=2) + chgcar_dir = tmp_path / "chgcars" + chgcar_dir.mkdir() + manifest = chgcar_dir / "manifest.jsonl" + manifest.write_text( + json.dumps( + { + "material_id": "mp-toy-0", + "model": "salted", + "submitted": False, + "error": "previous_run_died", + } + ) + + "\n" + ) + make_calls: list = [] + run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + skip_existing=True, + manifest_path=manifest, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + processed_ids = {call["metadata"]["material_id"] for call in make_calls} + # mp-toy-0 was previously failed, should be retried. 
+ assert "mp-toy-0" in processed_ids + + +class TestChgcarOrganisation: + def test_per_row_chgcar_dirs_are_unique(self, tmp_path, run_module): + """make_scf_speedup_pair takes a directory and stages CHGCAR + from it. Multiple rows must NOT share one directory or the + last write wins.""" + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=3) + chgcar_dir = tmp_path / "chgcars" + make_calls: list = [] + + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + seen = {Path(call["predicted_chgcar_dir"]).resolve() for call in make_calls} + assert len(seen) == len(records) == 3 + + def test_chgcar_layout_is_nested_by_model_then_material_id( + self, tmp_path, run_module + ): + """Layout must be ``chgcar_root/{model}/{material_id}/CHGCAR`` + so a material_id containing separator characters never causes + ambiguity.
Was previously a flat ``{model}__{material_id}/`` + which broke on synthesised IDs like ``oqmd__1234``.""" + from salted_ft.basis import BasisSpec + + in_parquet = _toy_parquet(tmp_path, n_rows=1) + chgcar_dir = tmp_path / "chgcars" + + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=chgcar_dir, + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock([]), + submit_fn=_submit_mock([]), + ) + chgcar_path = Path(records[0]["chgcar_path"]) + # Path tail must be .../{model}/{material_id}/CHGCAR + parts = chgcar_path.parts + assert parts[-1] == "CHGCAR" + assert parts[-2] == "mp-toy-0" + assert parts[-3] == "salted" + + +class TestRealisticRow: + """Catch mutation-killers a 2-atom H2 toy row misses: a missing + n_electrons rescale, a positions-reshape bug, or a grid/atom + mismatch all pass silently on the degenerate fixture.""" + + def test_5_atom_asymmetric_grid_unequal_n_electrons(self, tmp_path, run_module): + from salted_ft.basis import BasisSpec + + # 5 atoms: 1 Fe + 4 O (chosen so sum(Z)=26+4*8=58 != n_electrons=12.5). + # Asymmetric grid_shape catches axes-swap bugs.
+ n_atoms = 5 + atomic_numbers = np.array([26, 8, 8, 8, 8], dtype=np.int64) + rng = np.random.default_rng(0) + positions = rng.uniform(0, 5, size=(n_atoms, 3)).astype(np.float64) + rows = [ + { + "material_id": "mp-realistic-0", + "n_atoms": n_atoms, + "atomic_numbers": atomic_numbers, + "positions": positions.reshape(-1), + "lattice_vectors": (np.eye(3) * 5.0).reshape(-1), + "grid_shape": np.array([8, 10, 12], dtype=np.int64), + "n_electrons": 12.5, + } + ] + in_parquet = tmp_path / "realistic.parquet" + pd.DataFrame(rows).to_parquet(in_parquet) + + make_calls: list = [] + records = run_module.run_experiment( + model_name="salted", + test_parquet=in_parquet, + chgcar_dir=tmp_path / "chgcars", + basis_spec=BasisSpec(), + project="p", + worker="w", + dry_run=True, + make_pair_fn=_make_pair_mock(make_calls), + submit_fn=_submit_mock([]), + ) + # The row completed without error -- reshape correct, write_chgcar + # accepted asymmetric grid, n_electrons propagated to write_chgcar. + assert len(records) == 1 + assert records[0]["error"] is None, f"unexpected error: {records[0]['error']}" + assert records[0]["submitted"] is False # dry-run diff --git a/tests/test_submit_script.py b/tests/test_submit_script.py new file mode 100644 index 0000000..419fd38 --- /dev/null +++ b/tests/test_submit_script.py @@ -0,0 +1,184 @@ +"""TDD tests for the parameterized Adastra submit script. + +The script `submit_charge3net_adastra.sh` is now configurable via two env +vars: + + LEMATRHO_TRAINING_MODE "pretrained" (default) or "from_scratch" + LEMATRHO_DRY_RUN "1" prints the resolved train command and exits + +These tests pin the contract. + +They don't depend on Adastra. The script is sourced under bash with +LEMATRHO_DRY_RUN=1 so the venv activate / rocm-smi / srun calls are +skipped and the train invocation is printed instead of executed. 
+""" + +from __future__ import annotations + +import os +import shutil +import subprocess +from pathlib import Path + +import pytest + + +SUBMIT_SCRIPT = Path(__file__).resolve().parent.parent / "submit_charge3net_adastra.sh" + + +def _run(env_extra: dict) -> subprocess.CompletedProcess: + """Run the submit script under bash with LEMATRHO_DRY_RUN=1.""" + if shutil.which("bash") is None: + pytest.skip("bash not available in test environment") + env = { + **os.environ, + "LEMATRHO_DRY_RUN": "1", + # Avoid touching the user's real Adastra setup or W&B credentials. + "LEMATRHO_ADASTRA_SETUP": "/tmp/fake_setup_for_tests", + # SLURM env vars that the script would normally inherit. + "SLURM_NTASKS": "4", + "SLURM_NODELIST": "g0001", + "SLURM_JOB_ACCOUNT": "c1816212_mi250", + **env_extra, + } + return subprocess.run( + ["bash", str(SUBMIT_SCRIPT)], + env=env, + capture_output=True, + text=True, + check=False, + ) + + +def test_dry_run_mode_prints_train_command(): + """LEMATRHO_DRY_RUN=1 must print the resolved train command and exit 0.""" + result = _run({}) + assert result.returncode == 0, ( + f"dry-run exited {result.returncode}; stderr={result.stderr}" + ) + assert "charge3net_ft.train" in result.stdout, ( + f"dry-run output missing the train invocation; stdout={result.stdout}" + ) + + +def test_default_mode_is_pretrained(): + """Unset LEMATRHO_TRAINING_MODE -> pretrained MP checkpoint path is used.""" + result = _run({}) + assert result.returncode == 0 + out = result.stdout + assert "--ckpt-path" in out, ( + f"default (pretrained) run must pass --ckpt-path; stdout={out}" + ) + assert "charge3net_mp.pt" in out, ( + f"default run must point --ckpt-path at the MP checkpoint; stdout={out}" + ) + + +def test_pretrained_mode_uses_default_save_dir(): + """Pretrained mode writes to charge3net_checkpoints/ (no fromscratch suffix).""" + result = _run({"LEMATRHO_TRAINING_MODE": "pretrained"}) + assert result.returncode == 0 + assert ( + "charge3net_checkpoints " in (result.stdout 
+ " ") + or "charge3net_checkpoints\n" in result.stdout + or "/charge3net_checkpoints" in result.stdout + ) + assert "charge3net_checkpoints_fromscratch" not in result.stdout, ( + f"pretrained mode must NOT use the fromscratch save dir; stdout={result.stdout}" + ) + + +def test_from_scratch_mode_drops_ckpt_path(): + """LEMATRHO_TRAINING_MODE=from_scratch -> no --ckpt-path flag at all. + + Without --ckpt-path, ChargE3NetWrapper.__init__ initializes weights + fresh (no MP transfer). This is the comparison arm for the + pretrained vs from-scratch experiment. + """ + result = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}) + assert result.returncode == 0, ( + f"from_scratch run exited {result.returncode}; stderr={result.stderr}" + ) + out = result.stdout + assert "--ckpt-path" not in out, ( + f"from_scratch must not pass --ckpt-path; stdout={out}" + ) + # also confirm charge3net_mp.pt isn't referenced anywhere in the + # resolved command (defense against accidental partial passing) + assert "charge3net_mp.pt" not in out, ( + f"from_scratch must not reference the MP checkpoint; stdout={out}" + ) + + +def test_from_scratch_mode_uses_separate_save_dir(): + """From-scratch run writes to a different dir so checkpoints don't collide + with the pretrained run. + """ + result = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}) + assert result.returncode == 0 + out = result.stdout + assert "charge3net_checkpoints_fromscratch" in out, ( + f"from_scratch must write to charge3net_checkpoints_fromscratch/; stdout={out}" + ) + + +def test_from_scratch_mode_uses_distinct_wandb_name(): + """W&B run name differs between the two modes so the dashboard tells them apart.""" + # WANDB_NAME is what wandb reads at init time when no --name is passed. + pretrained = _run({"LEMATRHO_TRAINING_MODE": "pretrained"}).stdout + fromscratch = _run({"LEMATRHO_TRAINING_MODE": "from_scratch"}).stdout + # Both must mention WANDB_NAME or set it somehow. 
+ assert "WANDB_NAME" in pretrained or "wandb-run-name" in pretrained, ( + f"pretrained mode must set the wandb run name; stdout={pretrained}" + ) + assert "WANDB_NAME" in fromscratch or "wandb-run-name" in fromscratch, ( + f"from_scratch mode must set the wandb run name; stdout={fromscratch}" + ) + + # And they must differ. + # Extract WANDB_NAME value from each (simple regex-free parsing). + def _wandb_name(blob: str) -> str: + for line in blob.splitlines(): + if "WANDB_NAME=" in line: + return line.split("WANDB_NAME=", 1)[1].split()[0].strip("'\"") + return "" + + p_name = _wandb_name(pretrained) + f_name = _wandb_name(fromscratch) + assert p_name and f_name and p_name != f_name, ( + f"WANDB_NAME must differ between modes; pretrained={p_name!r}, fromscratch={f_name!r}" + ) + + +def test_invalid_mode_exits_with_clear_error(): + """An unrecognized mode must fail fast with a helpful message.""" + result = _run({"LEMATRHO_TRAINING_MODE": "garbage"}) + assert result.returncode != 0, ( + f"invalid mode must exit non-zero; stdout={result.stdout} stderr={result.stderr}" + ) + combined = (result.stdout + " " + result.stderr).lower() + assert "garbage" in combined or "training_mode" in combined or "mode" in combined, ( + f"error message should mention the bad value or the env var; " + f"stdout={result.stdout} stderr={result.stderr}" + ) + + +def test_batch_size_and_val_probes_match_paper(): + """Regression test: per-GPU batch=16, val_probes=1000 match the upstream paper.""" + result = _run({}) + assert "--batch-size 16" in result.stdout, ( + f"per-GPU batch must be 16 (paper); stdout={result.stdout}" + ) + assert "--val-probes 1000" in result.stdout, ( + f"val_probes must be 1000 (paper); stdout={result.stdout}" + ) + + +def test_wandb_mode_is_offline(): + """W&B must default to offline; api.wandb.ai is unreachable from + Adastra compute nodes (caused job 4969727 to crash after 1h47m). 
+ """ + result = _run({}) + assert "--wandb-mode offline" in result.stdout, ( + f"wandb-mode must default to offline; stdout={result.stdout}" + ) diff --git a/uv.lock b/uv.lock index ace9d7d..886ca90 100644 --- a/uv.lock +++ b/uv.lock @@ -2,10 +2,14 @@ version = 1 revision = 3 requires-python = ">=3.11" resolution-markers = [ - "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", - "python_full_version >= '3.12' and python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", - "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", - "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine == 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine != 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'ARM64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'ARM64' and sys_platform == 'win32'", "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", "python_full_version >= '3.12' and 
python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform != 'win32'", @@ -110,6 +114,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/6a/bc7e17a3e87a2985d3e8f4da4cd0f481060eb78fb08596c42be62c90a4d9/aiosignal-1.3.2-py2.py3-none-any.whl", hash = "sha256:45cde58e409a301715980c2b01d0c28bdde3770d8290b5eb2173759d9acb31a5", size = 7597, upload-time = "2024-12-13T17:10:38.469Z" }, ] +[[package]] +name = "annotated-doc" +version = "0.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" }, +] + [[package]] name = "annotated-types" version = "0.7.0" @@ -454,6 +467,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/fc/6a8cb64e5f0324877d503c854da15d76c1e50eb722e320b15345c4d0c6de/cffi-1.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:f6a16c31041f09ead72d69f583767292f750d24913dadacf5756b966aacb3f1a", size = 182009, upload-time = "2024-09-04T20:44:45.309Z" }, ] +[[package]] +name = "cftime" +version = "1.6.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/65/dc/470ffebac2eb8c54151eb893055024fe81b1606e7c6ff8449a588e9cd17f/cftime-1.6.5.tar.gz", hash = 
"sha256:8225fed6b9b43fb87683ebab52130450fc1730011150d3092096a90e54d1e81e", size = 326605, upload-time = "2025-10-13T18:56:26.352Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e4/f6/9da7aba9548ede62d25936b8b448acd7e53e5dcc710896f66863dcc9a318/cftime-1.6.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:474e728f5a387299418f8d7cb9c52248dcd5d977b2a01de7ec06bba572e26b02", size = 512733, upload-time = "2025-10-13T18:56:00.189Z" }, + { url = "https://files.pythonhosted.org/packages/1f/d5/d86ad95fc1fd89947c34b495ff6487b6d361cf77500217423b4ebcb1f0c2/cftime-1.6.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ab9e80d4de815cac2e2d88a2335231254980e545d0196eb34ee8f7ed612645f1", size = 492946, upload-time = "2025-10-13T18:56:01.262Z" }, + { url = "https://files.pythonhosted.org/packages/4f/93/d7e8dd76b03a9d5be41a3b3185feffc7ea5359228bdffe7aa43ac772a75b/cftime-1.6.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ad24a563784e4795cb3d04bd985895b5db49ace2cbb71fcf1321fd80141f9a52", size = 1689856, upload-time = "2025-10-13T19:39:12.873Z" }, + { url = "https://files.pythonhosted.org/packages/3e/8d/86586c0d75110f774e46e2bd6d134e2d1cca1dedc9bb08c388fa3df76acd/cftime-1.6.5-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a3cda6fd12c7fb25eff40a6a857a2bf4d03e8cc71f80485d8ddc65ccbd80f16a", size = 1718573, upload-time = "2025-10-13T18:56:02.788Z" }, + { url = "https://files.pythonhosted.org/packages/bb/fe/7956914cfc135992e89098ebbc67d683c51ace5366ba4b114fef1de89b21/cftime-1.6.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:28cda78d685397ba23d06273b9c916c3938d8d9e6872a537e76b8408a321369b", size = 1788563, upload-time = "2025-10-13T18:56:04.075Z" }, + { url = "https://files.pythonhosted.org/packages/e5/c7/6669708fcfe1bb7b2a7ce693b8cc67165eac00d3ac5a5e8f6ce1be551ff9/cftime-1.6.5-cp311-cp311-win_amd64.whl", hash = "sha256:93ead088e3a216bdeb9368733a0ef89a7451dfc1d2de310c1c0366a56ad60dc8", size = 473631, 
upload-time = "2025-10-13T18:56:05.159Z" }, + { url = "https://files.pythonhosted.org/packages/82/c5/d70cb1ab533ca790d7c9b69f98215fa4fead17f05547e928c8f2b8f96e54/cftime-1.6.5-cp311-cp311-win_arm64.whl", hash = "sha256:3384d69a0a7f3d45bded21a8cbcce66c8ba06c13498eac26c2de41b1b9b6e890", size = 459383, upload-time = "2026-01-02T21:16:47.317Z" }, + { url = "https://files.pythonhosted.org/packages/b6/c1/e8cb7f78a3f87295450e7300ebaecf83076d96a99a76190593d4e1d2be40/cftime-1.6.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:eef25caed5ebd003a38719bd3ff8847cd52ef2ea56c3ebdb2c9345ba131fc7c5", size = 504175, upload-time = "2025-10-13T18:56:06.398Z" }, + { url = "https://files.pythonhosted.org/packages/50/1a/86e1072b09b2f9049bb7378869f64b6747f96a4f3008142afed8955b52a4/cftime-1.6.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c87d2f3b949e45463e559233c69e6a9cf691b2b378c1f7556166adfabbd1c6b0", size = 485980, upload-time = "2025-10-13T18:56:08.669Z" }, + { url = "https://files.pythonhosted.org/packages/35/28/d3177b60da3f308b60dee2aef2eb69997acfab1e863f0bf0d2a418396ce5/cftime-1.6.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:82cb413973cc51b55642b3a1ca5b28db5b93a294edbef7dc049c074b478b4647", size = 1591166, upload-time = "2025-10-13T19:39:14.109Z" }, + { url = "https://files.pythonhosted.org/packages/d1/fd/a7266970312df65e68b5641b86e0540a739182f5e9c62eec6dbd29f18055/cftime-1.6.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:85ba8e7356d239cfe56ef7707ac30feaf67964642ac760a82e507ee3c5db4ac4", size = 1642614, upload-time = "2025-10-13T18:56:09.815Z" }, + { url = "https://files.pythonhosted.org/packages/c4/73/f0035a4bc2df8885bb7bd5fe63659686ea1ec7d0cc74b4e3d50e447402e5/cftime-1.6.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:456039af7907a3146689bb80bfd8edabd074c7f3b4eca61f91b9c2670addd7ad", size = 1688090, upload-time = "2025-10-13T18:56:11.442Z" }, + { url = 
"https://files.pythonhosted.org/packages/88/15/8856a0ab76708553ff597dd2e617b088c734ba87dc3fd395e2b2f3efffe8/cftime-1.6.5-cp312-cp312-win_amd64.whl", hash = "sha256:da84534c43699960dc980a9a765c33433c5de1a719a4916748c2d0e97a071e44", size = 464840, upload-time = "2025-10-13T18:56:12.506Z" }, + { url = "https://files.pythonhosted.org/packages/3a/85/451009a986d9273d2208fc0898aa00262275b5773259bf3f942f6716a9e7/cftime-1.6.5-cp312-cp312-win_arm64.whl", hash = "sha256:c62cd8db9ea40131eea7d4523691c5d806d3265d31279e4a58574a42c28acd77", size = 450534, upload-time = "2026-01-02T21:16:48.784Z" }, + { url = "https://files.pythonhosted.org/packages/2e/60/74ea344b3b003fada346ed98a6899085d6fd4c777df608992d90c458fda6/cftime-1.6.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4aba66fd6497711a47c656f3a732c2d1755ad15f80e323c44a8716ebde39ddd5", size = 502453, upload-time = "2025-10-13T18:56:13.545Z" }, + { url = "https://files.pythonhosted.org/packages/1e/14/adb293ac6127079b49ff11c05cf3d5ce5c1f17d097f326dc02d74ddfcb6e/cftime-1.6.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:89e7cba699242366e67d6fb5aee579440e791063f92a93853610c91647167c0d", size = 484541, upload-time = "2025-10-13T18:56:14.612Z" }, + { url = "https://files.pythonhosted.org/packages/4f/74/bb8a4566af8d0ef3f045d56c462a9115da4f04b07c7fbbf2b4875223eebd/cftime-1.6.5-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2f1eb43d7a7b919ec99aee709fb62ef87ef1cf0679829ef93d37cc1c725781e9", size = 1591014, upload-time = "2025-10-13T19:39:15.346Z" }, + { url = "https://files.pythonhosted.org/packages/ba/08/52f06ff2f04d376f9cd2c211aefcf2b37f1978e43289341f362fc99f6a0e/cftime-1.6.5-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e02a1d80ffc33fe469c7db68aa24c4a87f01da0c0c621373e5edadc92964900b", size = 1633625, upload-time = "2025-10-13T18:56:15.745Z" }, + { url = 
"https://files.pythonhosted.org/packages/cf/33/03e0b23d58ea8fab94ecb4f7c5b721e844a0800c13694876149d98830a73/cftime-1.6.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:18ab754805233cdd889614b2b3b86a642f6d51a57a1ec327c48053f3414f87d8", size = 1684269, upload-time = "2025-10-13T18:56:17.04Z" }, + { url = "https://files.pythonhosted.org/packages/a4/60/a0cfba63847b43599ef1cdbbf682e61894994c22b9a79fd9e1e8c7e9de41/cftime-1.6.5-cp313-cp313-win_amd64.whl", hash = "sha256:6c27add8f907f4a4cd400e89438f2ea33e2eb5072541a157a4d013b7dbe93f9c", size = 465364, upload-time = "2025-10-13T18:56:18.05Z" }, + { url = "https://files.pythonhosted.org/packages/3d/e8/ec32f2aef22c15604e6fda39ff8d581a00b5469349f8fba61640d5358d2c/cftime-1.6.5-cp313-cp313-win_arm64.whl", hash = "sha256:31d1ff8f6bbd4ca209099d24459ec16dea4fb4c9ab740fbb66dd057ccbd9b1b9", size = 450468, upload-time = "2026-01-02T21:16:50.193Z" }, + { url = "https://files.pythonhosted.org/packages/ea/6c/a9618f589688358e279720f5c0fe67ef0077fba07334ce26895403ebc260/cftime-1.6.5-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c69ce3bdae6a322cbb44e9ebc20770d47748002fb9d68846a1e934f1bd5daf0b", size = 502725, upload-time = "2025-10-13T18:56:19.424Z" }, + { url = "https://files.pythonhosted.org/packages/d8/e3/da3c36398bfb730b96248d006cabaceed87e401ff56edafb2a978293e228/cftime-1.6.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:e62e9f2943e014c5ef583245bf2e878398af131c97e64f8cd47c1d7baef5c4e2", size = 485445, upload-time = "2025-10-13T18:56:20.853Z" }, + { url = "https://files.pythonhosted.org/packages/32/93/b05939e5abd14bd1ab69538bbe374b4ee2a15467b189ff895e9a8cdaddf6/cftime-1.6.5-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7da5fdaa4360d8cb89b71b8ded9314f2246aa34581e8105c94ad58d6102d9e4f", size = 1584434, upload-time = "2025-10-13T19:39:17.084Z" }, + { url = 
"https://files.pythonhosted.org/packages/7f/89/648397f9936e0b330999c4e776ebf296ec3c6a65f9901687dbca4ab820da/cftime-1.6.5-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bff865b4ea4304f2744a1ad2b8149b8328b321dd7a2b9746ef926d229bd7cd49", size = 1609812, upload-time = "2025-10-13T18:56:21.971Z" }, + { url = "https://files.pythonhosted.org/packages/e7/0f/901b4835aa67ad3e915605d4e01d0af80a44b114eefab74ae33de6d36933/cftime-1.6.5-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e552c5d1c8a58f25af7521e49237db7ca52ed2953e974fe9f7c4491e95fdd36c", size = 1669768, upload-time = "2025-10-13T18:56:24.027Z" }, + { url = "https://files.pythonhosted.org/packages/22/d5/e605e4b28363e7a9ae98ed12cabbda5b155b6009270e6a231d8f10182a17/cftime-1.6.5-cp314-cp314-win_amd64.whl", hash = "sha256:e645b095dc50a38ac454b7e7f0742f639e7d7f6b108ad329358544a6ff8c9ba2", size = 463818, upload-time = "2025-10-13T18:56:25.376Z" }, + { url = "https://files.pythonhosted.org/packages/3d/89/a8f85ae697ff10206ec401c2621f5ca9f327554f586d62f244739ceeb347/cftime-1.6.5-cp314-cp314-win_arm64.whl", hash = "sha256:b9044d7ac82d3d8af189df1032fdc871bbd3f3dd41a6ec79edceb5029b71e6e0", size = 459862, upload-time = "2026-01-02T20:45:02.625Z" }, + { url = "https://files.pythonhosted.org/packages/ab/05/7410e12fd03a0c52717e74e6a1b49958810807dda212e23b65d43ea99676/cftime-1.6.5-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9ef56460cb0576e1a9161e1428c9e1a633f809a23fa9d598f313748c1ae5064e", size = 533781, upload-time = "2026-01-02T20:45:04.818Z" }, + { url = "https://files.pythonhosted.org/packages/44/ba/10e3546426d3ed9f9cc82e4a99836bb6fac1642c7830f7bdd0ac1c3f0805/cftime-1.6.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:4f4873d38b10032f9f3111c547a1d485519ae64eee6a7a2d091f1f8b08e1ba50", size = 515218, upload-time = "2026-01-02T20:45:06.788Z" }, + { url = 
"https://files.pythonhosted.org/packages/bd/68/efa11eae867749e921bfec6a865afdba8166e96188112dde70bb8bb49254/cftime-1.6.5-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ccce0f4c9d3f38dd948a117e578b50d0e0db11e2ca9435fb358fd524813e4b61", size = 1579932, upload-time = "2026-01-02T20:45:11.194Z" }, + { url = "https://files.pythonhosted.org/packages/9d/6c/0971e602c1390a423e6621dfbad9f1d375186bdaf9c9c7f75e06f1fbf355/cftime-1.6.5-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:19cbfc5152fb0b34ce03acf9668229af388d7baa63a78f936239cb011ccbe6b1", size = 1555894, upload-time = "2026-01-02T20:45:16.351Z" }, + { url = "https://files.pythonhosted.org/packages/ad/fc/8475a15b7c3209a4a68b563dfc5e01ce74f2d8b9822372c3d30c68ab7f39/cftime-1.6.5-cp314-cp314t-win_amd64.whl", hash = "sha256:4470cd5ef3c2514566f53efbcbb64dd924fa0584637d90285b2f983bd4ee7d97", size = 513027, upload-time = "2026-01-02T20:45:20.023Z" }, + { url = "https://files.pythonhosted.org/packages/f7/80/4ecbda8318fbf40ad4e005a4a93aebba69e81382e5b4c6086251cd5d0ee8/cftime-1.6.5-cp314-cp314t-win_arm64.whl", hash = "sha256:034c15a67144a0a5590ef150c99f844897618b148b87131ed34fda7072614662", size = 469065, upload-time = "2026-01-02T20:45:23.398Z" }, +] + [[package]] name = "charset-normalizer" version = "3.4.2" @@ -502,6 +560,22 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/20/94/c5790835a017658cbfabd07f3bfb549140c3ac458cfc196323996b10095a/charset_normalizer-3.4.2-py3-none-any.whl", hash = "sha256:7f56930ab0abd1c45cd15be65cc741c28b1c9a34876ce8c17a2fa107810c0af0", size = 52626, upload-time = "2025-05-02T08:34:40.053Z" }, ] +[[package]] +name = "chemfiles" +version = "0.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c2/51/35538663b6384add778945735478da66b7c3095649654325d001922f30f8/chemfiles-0.10.4.tar.gz", hash = 
"sha256:f9e5ece3fcc8b63fdc2708d4ecc2ba5862ae2ab6790447bffc10c1b34ef2f445", size = 3575412, upload-time = "2023-05-23T10:49:17.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8b/0d/e5a214dddec845c425cda2cb2273a95b2c5f77be9404d02c4f48b4e6992b/chemfiles-0.10.4-1-py2.py3-none-win_amd64.whl", hash = "sha256:5c1b50a7fd56d014f930e38a838c92098bd047a3e989ba4b89ff657c6d16e38a", size = 1129225, upload-time = "2023-05-24T15:02:46.683Z" }, + { url = "https://files.pythonhosted.org/packages/84/0e/409d1fe39dc24f3ac47dd384e78462fc4eb0435a169afe5b488cf6ded39b/chemfiles-0.10.4-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:10a4e641605db56321316310f620746db350691d7c9edc433fe2a65984e2278b", size = 1497588, upload-time = "2023-05-23T10:49:04.561Z" }, + { url = "https://files.pythonhosted.org/packages/78/5f/d7d7347db0d1a92577aa27d9412adea002295263d52cca57ff14c92cde56/chemfiles-0.10.4-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:626725b0ea907d995cbbba99df1d19c474f8ebecdea8d0d390b7f3eaf2c91039", size = 1350827, upload-time = "2023-05-23T10:49:07.125Z" }, + { url = "https://files.pythonhosted.org/packages/3a/d5/beb71f372e650ba75e3eac246a17daa09a08aeed46580b62af35234d01f2/chemfiles-0.10.4-py2.py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4dbf6fa7ad5b2a1ad1415fbca905ce3a02c71cc2aa7fbce18a2b7d13c01a3664", size = 1751189, upload-time = "2023-05-23T10:49:10.237Z" }, + { url = "https://files.pythonhosted.org/packages/50/4c/380de5755146e27236cdecf02b7fe5da4c1f3786716baee5b3a245026acb/chemfiles-0.10.4-py2.py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ef8f2b9fa65885658088180bb33971d1337bc8542220c710d1f6f3c1a6d661d4", size = 1632279, upload-time = "2023-05-23T10:49:12.365Z" }, +] + [[package]] name = "click" version = "8.2.1" @@ -594,7 +668,8 @@ name = "cryptography" version = "45.0.7" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.14' and 
platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", ] dependencies = [ @@ -639,9 +714,12 @@ name = "cryptography" version = "46.0.0" source = { registry = "https://pypi.org/simple" } resolution-markers = [ - "python_full_version >= '3.12' and python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", - "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", - "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine == 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine != 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'ARM64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'ARM64' and sys_platform == 'win32'", "python_full_version >= '3.12' and python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform != 'win32'", 
"python_full_version < '3.12' and sys_platform != 'win32'", @@ -1048,6 +1126,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3a/5d/b645a1e7c71ba562cf31987ee7499f603b6b49f67ccab521b3b600f53a1e/gemmi-0.7.1-cp313-cp313-win_amd64.whl", hash = "sha256:402a71c935cab167ac6a7a29045e47a972388ef6f62fa3f477d8b0241fe53d4e", size = 1928436, upload-time = "2025-03-24T19:20:03.183Z" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.50" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/f6/354ae6491228b5eb40e10d89c4d13c651fe1cf7556e35ebdded50cff57ce/gitpython-3.1.50.tar.gz", hash = "sha256:80da2d12504d52e1f998772dc5baf6e553f8d2fcfe1fcc226c9d9a2ee3372dcc", size = 219798, upload-time = "2026-05-06T04:01:26.571Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/7a/1c6e3562dfd8950adbb11ffbc65d21e7c89d01a6e4f137fa981056de25c5/gitpython-3.1.50-py3-none-any.whl", hash = "sha256:d352abe2908d07355014abdd21ddf798c2a961469239afec4962e9da884858f9", size = 212507, upload-time = "2026-05-06T04:01:23.799Z" }, +] + +[[package]] +name = "graph2mat" +version = "0.0.13" +source = { registry = "https://pypi.org/simple" } +dependencies 
= [ + { name = "ase" }, + { name = "numpy" }, + { name = "scipy" }, + { name = "sisl", extra = ["viz"] }, + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/37/bf1deadade49d409d17c549f50b76f3f8de0c810817b49005dfc966c9f89/graph2mat-0.0.13.tar.gz", hash = "sha256:23f251ec044e0cc79c126c3cc687ada17708f316265d69f75d3ab76a14591a03", size = 1251793, upload-time = "2025-10-14T11:43:29.014Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/8b/7ebe6acdbd2bd8623a7014d411d4b447d8d6fa3994bfa16fae2b9fa39787/graph2mat-0.0.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bfb9c25cb2aea6edd8f365355c81589558bcd7a6f734b626842995a6449ebd0e", size = 363873, upload-time = "2025-10-14T11:43:19.374Z" }, + { url = "https://files.pythonhosted.org/packages/54/69/d0916760e124f23ecd407c58428b3b9f00709897008cfdf5602b1525f9bd/graph2mat-0.0.13-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:350bbd488c75ffdf4821ad98cbbc1f05db1b0fd80f67dc0aa75028355e501ad7", size = 450680, upload-time = "2025-10-14T11:43:20.272Z" }, + { url = "https://files.pythonhosted.org/packages/a0/ec/6766f2b92138563a73678ce96c61a5decf2e12cd5cedbe0f390c43a682b2/graph2mat-0.0.13-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:da7d7bd08d65506957c1d369b2640d84c37894af7ed8ba78a1bcd29d07671549", size = 365265, upload-time = "2025-10-14T11:43:21.513Z" }, + { url = "https://files.pythonhosted.org/packages/e1/86/a990e6340b06f366180007bdfbeadb3868b935834c159c2e9878f70a80a3/graph2mat-0.0.13-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6a5eb39e897a39f53cc510e992afb1becdbe52a23cf9c487ebdc1164163ab752", size = 444182, upload-time = "2025-10-14T11:43:22.776Z" }, + { url = "https://files.pythonhosted.org/packages/5e/51/a28401d4be00822f81557d47c14b8f2c344a866865081b336d24d2bc5c4f/graph2mat-0.0.13-cp313-cp313-macosx_11_0_arm64.whl", hash = 
"sha256:e010c75262b4fc5ff7dd1bf90b22d67c4ea4bc0d83a8a3c3507d4a8d2e3e79e0", size = 363479, upload-time = "2025-10-14T11:43:23.791Z" }, + { url = "https://files.pythonhosted.org/packages/ee/07/9c857c0d3ca21a553b4e80869f1145469f5ffa0b34d775a095a3fec81d21/graph2mat-0.0.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:12f7697491a7b526485e4c136b0b5634f96b755495aae77b468a930dfbaac239", size = 445299, upload-time = "2025-10-14T11:43:24.91Z" }, +] + [[package]] name = "gunicorn" version = "25.1.0" @@ -1103,6 +1226,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442, upload-time = "2024-09-15T18:07:37.964Z" }, ] +[[package]] +name = "imageio" +version = "2.37.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/84/93bcd1300216ea50811cee96873b84a1bebf8d0489ffaf7f2a3756bab866/imageio-2.37.3.tar.gz", hash = "sha256:bbb37efbfc4c400fcd534b367b91fcd66d5da639aaa138034431a1c5e0a41451", size = 389673, upload-time = "2026-03-09T11:31:12.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/49/fa/391e437a34e55095173dca5f24070d89cbc233ff85bf1c29c93248c6588d/imageio-2.37.3-py3-none-any.whl", hash = "sha256:46f5bb8522cd421c0f5ae104d8268f569d856b29eb1a13b92829d1970f32c9f0", size = 317646, upload-time = "2026-03-09T11:31:10.771Z" }, +] + [[package]] name = "ipykernel" version = "6.29.5" @@ -1374,6 +1510,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/40/23569737873cc9637fd488606347e9dd92b9fa37ba4fcda1f98ee5219a97/latexcodec-3.0.1-py3-none-any.whl", hash = "sha256:a9eb8200bff693f0437a69581f7579eb6bca25c4193515c09900ce76451e452e", size = 18532, upload-time = 
"2025-06-17T18:47:30.726Z" }, ] +[[package]] +name = "lazy-loader" +version = "0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "packaging" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/49/ac/21a1f8aa3777f5658576777ea76bfb124b702c520bbe90edf4ae9915eafa/lazy_loader-0.5.tar.gz", hash = "sha256:717f9179a0dbed357012ddad50a5ad3d5e4d9a0b8712680d4e687f5e6e6ed9b3", size = 15294, upload-time = "2026-03-06T15:45:09.054Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/a1/8d812e53a5da1687abb10445275d41a8b13adb781bbf7196ddbcf8d88505/lazy_loader-0.5-py3-none-any.whl", hash = "sha256:ab0ea149e9c554d4ffeeb21105ac60bed7f3b4fd69b1d2360a4add51b170b005", size = 8044, upload-time = "2026-03-06T15:45:07.668Z" }, +] + [[package]] name = "lemat-rho" version = "0.1.0" @@ -1381,28 +1529,40 @@ source = { virtual = "." } dependencies = [ { name = "ase" }, { name = "atomate2" }, + { name = "chemfiles" }, { name = "e3nn" }, { name = "fireworks" }, + { name = "graph2mat" }, { name = "ipykernel" }, { name = "lz4" }, { name = "material-hasher" }, - { name = "pandas" }, + { name = "metatensor" }, + { name = "numpy" }, { name = "pyarrow" }, + { name = "python-dotenv" }, { name = "scipy" }, + { name = "torch" }, + { name = "wandb" }, ] [package.metadata] requires-dist = [ { name = "ase", specifier = ">=3.25.0" }, { name = "atomate2" }, - { name = "e3nn", specifier = ">=0.6.0" }, + { name = "chemfiles", specifier = ">=0.10.4" }, + { name = "e3nn", specifier = ">=0.5.0" }, { name = "fireworks" }, + { name = "graph2mat", specifier = ">=0.0.13" }, { name = "ipykernel", specifier = ">=6.29.5" }, - { name = "lz4", specifier = ">=4.4.5" }, + { name = "lz4", specifier = ">=4.0.0" }, { name = "material-hasher", git = "https://github.com/LeMaterial/lematerial-hasher" }, - { name = "pandas", specifier = ">=2.3.0" }, - { name = "pyarrow", specifier = ">=20.0.0" }, - { name = "scipy", specifier = ">=1.16.0" }, + { name = "metatensor", 
specifier = ">=0.2.0" }, + { name = "numpy", specifier = ">=1.24" }, + { name = "pyarrow", specifier = ">=14.0.0" }, + { name = "python-dotenv", specifier = ">=1.0.0" }, + { name = "scipy", specifier = ">=1.10.0" }, + { name = "torch", specifier = ">=2.0" }, + { name = "wandb", specifier = ">=0.16.0" }, ] [[package]] @@ -1521,6 +1681,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/df/e6ed9ae87af6941300f111b7cb1b69cdc5f605bb86e7815f5cc3d4043d22/maggma-0.72.1-py3-none-any.whl", hash = "sha256:5aa894a3a2c0cef6629bb122b8025125af2099d09b5b284c9adfd75d9b56dfb1", size = 123654, upload-time = "2026-02-11T18:52:44.788Z" }, ] +[[package]] +name = "markdown-it-py" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mdurl" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/06/ff/7841249c247aa650a76b9ee4bbaeae59370dc8bfd2f6c01f3630c35eb134/markdown_it_py-4.2.0.tar.gz", hash = "sha256:04a21681d6fbb623de53f6f364d352309d4094dd4194040a10fd51833e418d49", size = 82454, upload-time = "2026-05-07T12:08:28.36Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/81/4da04ced5a082363ecfa159c010d200ecbd959ae410c10c0264a38cac0f5/markdown_it_py-4.2.0-py3-none-any.whl", hash = "sha256:9f7ebbcd14fe59494226453aed97c1070d83f8d24b6fc3a3bcf9a38092641c4a", size = 91687, upload-time = "2026-05-07T12:08:27.182Z" }, +] + [[package]] name = "markupsafe" version = "3.0.2" @@ -1639,6 +1811,70 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = "2024-04-15T13:44:43.265Z" }, ] +[[package]] +name = "mdurl" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/d6/54/cfe61301667036ec958cb99bd3efefba235e65cdeb9c84d24a8293ba1d90/mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba", size = 8729, upload-time = "2022-08-14T12:40:10.846Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, +] + +[[package]] +name = "metatensor" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, + { name = "metatensor-learn" }, + { name = "metatensor-operations" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/58/172e96ccdca4d8d572579adc69b593dad79b74497c116ed86979257a5cbd/metatensor-0.2.0.tar.gz", hash = "sha256:ce3f8a34796d2aaa7e74b2d1392f64a05e85d1ca3e3878c1e9259e6a6a7a8138", size = 5373, upload-time = "2024-01-26T17:27:15.203Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f7/28/fd3f02ccb23764af794e953262127a7f2aed35073f460da6f279fe1c2b15/metatensor-0.2.0-py3-none-any.whl", hash = "sha256:60008fee73f49b349350d9d93dec63ea4e1cf30beceae17d543561d69a7ac393", size = 3702, upload-time = "2024-01-26T17:26:59.518Z" }, +] + +[[package]] +name = "metatensor-core" +version = "0.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/d5/18f05f73a0af0517dbbf441e673abf88bccfec6a92a1beeebbc9df9d5ed9/metatensor_core-0.2.0.tar.gz", hash = "sha256:30200451eb70e635fdef5dfd46476d0303b1757b1e34c23f9c9e568c9d188545", size = 177741, upload-time = "2026-05-13T15:45:51.837Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/a6/99/4a81ad15c63b82be70e8e9ca1ae95b31b7c91d512b684c8a26fb0671a746/metatensor_core-0.2.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c5e82760244c7233c41547d6c015f38caf7f3af589e0a7f827cad4a0c0ef0bbf", size = 549924, upload-time = "2026-05-13T15:45:08.494Z" }, + { url = "https://files.pythonhosted.org/packages/f0/11/8cd0fea97a5be6793596f573bb2fabf5dfd00a67884f9c77e6c7331c3921/metatensor_core-0.2.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:286f477f96520c046dff35dbc3a40ac3cfdef540e1c7bc071e91769f68dbb8f8", size = 582626, upload-time = "2026-05-13T15:45:18.982Z" }, + { url = "https://files.pythonhosted.org/packages/b4/09/91e7f49401597f0858087a3e603f98bb78d900895510b799fa445e1a4a8e/metatensor_core-0.2.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f0529b6d3966fff6ad85e988443c2acf22d0251f52be38d4dce6fa4d617c0e81", size = 594606, upload-time = "2026-05-13T15:45:33.145Z" }, + { url = "https://files.pythonhosted.org/packages/bc/50/e090f6a2c56a6c822bac818ca5d900568a17df8ea6a2d1bf9f8d8cde9fc0/metatensor_core-0.2.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dbdf693cdb0436736e8d678e2d45dec5f8e47df18c7a4f775eb546c0106fb867", size = 634966, upload-time = "2026-05-13T15:45:39.684Z" }, + { url = "https://files.pythonhosted.org/packages/55/65/84df97b3922d50954644b06397e337e4a52da98ddd92f52a1532329d1378/metatensor_core-0.2.0-py3-none-win_amd64.whl", hash = "sha256:2b7dfc59c920b1d06dbebd2e7afa0a2395ea1ef01e437ad7a0e4d213f2034ce1", size = 533600, upload-time = "2026-05-13T15:45:44.907Z" }, +] + +[[package]] +name = "metatensor-learn" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, + { name = "metatensor-operations" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/bd/0fd1901b44635a24f40528a6244b5889143747ddeb841ae0201255c1f22e/metatensor_learn-0.5.0.tar.gz", hash = 
"sha256:0b1d30ed217d70de7851ed1d48421515d9c6a1be7f50d9b1b43f92a689be51d0", size = 25221, upload-time = "2026-05-13T15:45:54.582Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/85/f8e2061c58cf4ea22681be48f5aecf0074abd9717fcb8f05dd3ea6e370fc/metatensor_learn-0.5.0-py3-none-any.whl", hash = "sha256:ad8863dac144f03c9ca80ec625c9e35b87ceb82438a0a80c0bf14e9dcc1b607c", size = 32888, upload-time = "2026-05-13T15:45:49.25Z" }, +] + +[[package]] +name = "metatensor-operations" +version = "0.5.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "metatensor-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3a/98/83e132e8aca5bc05ffaffd342566ab4abd8e7bb579de6df1fde8b8602abb/metatensor_operations-0.5.0.tar.gz", hash = "sha256:e1cb0a8c358842e94ac3680fa9ec6f7a006cb519b6950ed1bb7001a209087cfc", size = 57735, upload-time = "2026-05-13T15:45:53.568Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/31/18d10b7d6d2ef5829a33c52cb0148730a951f0b3ad13aac5c4fae510ccfd/metatensor_operations-0.5.0-py3-none-any.whl", hash = "sha256:9536562c9e02a5c723fc118be671e8ff37e8e69caf2dc4a2bd97fca5271ec510", size = 79354, upload-time = "2026-05-13T15:45:47.855Z" }, +] + [[package]] name = "mongomock" version = "4.3.0" @@ -1851,6 +2087,62 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, ] +[[package]] +name = "netcdf4" +version = "1.7.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine == 'ARM64' and 
platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine == 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine == 'ARM64' and sys_platform == 'win32'", +] +dependencies = [ + { name = "certifi", marker = "platform_machine == 'ARM64' and sys_platform == 'win32'" }, + { name = "cftime", marker = "platform_machine == 'ARM64' and sys_platform == 'win32'" }, + { name = "numpy", marker = "platform_machine == 'ARM64' and sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0e/76/7bc801796dee752c1ce9cd6935564a6ee79d5c9d9ef9192f57b156495a35/netcdf4-1.7.3.tar.gz", hash = "sha256:83f122fc3415e92b1d4904fd6a0898468b5404c09432c34beb6b16c533884673", size = 836095, upload-time = "2025-10-13T18:38:00.76Z" } + +[[package]] +name = "netcdf4" +version = "1.7.4" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine != 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'ARM64' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", + "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform != 'win32'", + "python_full_version < '3.12' and sys_platform != 'win32'", +] +dependencies 
= [ + { name = "certifi", marker = "platform_machine != 'ARM64' or sys_platform != 'win32'" }, + { name = "cftime", marker = "platform_machine != 'ARM64' or sys_platform != 'win32'" }, + { name = "numpy", marker = "platform_machine != 'ARM64' or sys_platform != 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/34/b6/0370bb3af66a12098da06dc5843f3b349b7c83ccbdf7306e7afa6248b533/netcdf4-1.7.4.tar.gz", hash = "sha256:cdbfdc92d6f4d7192ca8506c9b3d4c1d9892969ff28d8e8e1fc97ca08bf12164", size = 838352, upload-time = "2026-01-05T02:27:38.593Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/de/38ed7e1956943d28e8ea74161e97c3a00fb98d6d08943b4fd21bae32c240/netcdf4-1.7.4-cp311-abi3-macosx_13_0_x86_64.whl", hash = "sha256:dec70e809cc65b04ebe95113ee9c85ba46a51c3a37c058d2b2b0cadc4d3052d8", size = 23427499, upload-time = "2026-01-05T02:27:06.568Z" }, + { url = "https://files.pythonhosted.org/packages/e5/70/2f73c133b71709c412bc81d8b721e28dc6237ba9d7dad861b7bfbb70408a/netcdf4-1.7.4-cp311-abi3-macosx_14_0_arm64.whl", hash = "sha256:75cf59100f0775bc4d6b9d4aca7cbabd12e2b8cf3b9a4fb16d810b92743a315a", size = 22847667, upload-time = "2026-01-05T02:27:09.421Z" }, + { url = "https://files.pythonhosted.org/packages/77/ce/43a3c0c41a6e2e940d87feea79d29aa88302211ac122604838f8a5a48de6/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ddfc7e9d261125c74708119440c85ea288b5fee41db676d2ba1ce9be11f96932", size = 10274769, upload-time = "2026-01-05T21:31:19.243Z" }, + { url = "https://files.pythonhosted.org/packages/7b/7a/a8d32501bb95ecff342004a674720164f95ad616f269450b3bc13dc88ae3/netcdf4-1.7.4-cp311-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a72c9f58767779ec14cb7451c3b56bdd8fdc027a792fac2062b14e090c5617f3", size = 10123122, upload-time = "2026-01-05T21:31:22.773Z" }, + { url = 
"https://files.pythonhosted.org/packages/18/68/e89b4fa9242e59326c849c39ce0f49eb68499603c639405a8449900a4f15/netcdf4-1.7.4-cp311-abi3-win_amd64.whl", hash = "sha256:9476e1f23161ae5159cd1548c50c8a37922e77d76583e247133f256ef7b825fc", size = 21299637, upload-time = "2026-01-05T02:27:11.856Z" }, + { url = "https://files.pythonhosted.org/packages/6c/fc/edd41a3607241027aa4533e7f18e0cd647e74dde10a63274c65350f59967/netcdf4-1.7.4-cp311-abi3-win_arm64.whl", hash = "sha256:876ad9d58f09c98741c066c726164c45a098a58fb90e5fac9e74de4bb8a793fd", size = 2386377, upload-time = "2026-01-05T02:27:13.808Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3e/1e83534ba68459bc5ae39df46fa71003984df58aabf31f7dcd6e22ecddb0/netcdf4-1.7.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56688c03444fffe0d0c7512cb45245e650389cd841c955b30e4552fa681c4cd9", size = 10519821, upload-time = "2026-01-05T02:27:15.413Z" }, + { url = "https://files.pythonhosted.org/packages/c0/8c/a15d6fe97f81d6d5202b17838a9a298b5955b3e9971e20609195112829b5/netcdf4-1.7.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ecf471ba8a6ddb2200121949bedfa0095db228822f38227d5da680694a38358", size = 10371133, upload-time = "2026-01-05T02:27:17.224Z" }, + { url = "https://files.pythonhosted.org/packages/d8/2b/684b15dd4791f8be295b2f6fa97377bbc07a768478a63b7d3c4951712e36/netcdf4-1.7.4-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a5841de0735e8e4875b367c668e81d334287858d64dd9f3e3e2261e808c84922", size = 10395635, upload-time = "2026-01-05T02:27:19.655Z" }, + { url = "https://files.pythonhosted.org/packages/37/dc/44d21524cf1b1c64254f92e22395a7a10f70c18f3a13a18ac9db258760f7/netcdf4-1.7.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:86fac03a8c5b250d57866e7d98918a64742e4b0de1681c5c86bac5726bab8aee", size = 10237725, upload-time = "2026-01-05T02:27:22.298Z" }, + { url = 
"https://files.pythonhosted.org/packages/d4/9d/c3ddf54296ad8f18f02f77f23452bdb0971aece1b87e84bab9d734bf72cc/netcdf4-1.7.4-cp314-cp314t-macosx_13_0_x86_64.whl", hash = "sha256:ad083d260301b5add74b1669c75ab0df03bdf986decfcc092cb45eec2615b5f1", size = 23515258, upload-time = "2026-01-05T02:27:24.837Z" }, + { url = "https://files.pythonhosted.org/packages/dd/44/bc0346e995d436d03fab682b7fbd2a9adcf0db6a05790b8f24853bf08170/netcdf4-1.7.4-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:7f22014092cc9da3f056b0368e2e38c42afd5725c87ad4843eb2f467e16dd4f6", size = 22910171, upload-time = "2026-01-05T02:27:27.166Z" }, + { url = "https://files.pythonhosted.org/packages/30/6b/f9bc3f43c55e2dac72ee9f98d77860789bdd5d50c29adf164a6bdb303078/netcdf4-1.7.4-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:224a15434c165a5e0225e5831f591edf62533044b1ce62fdfee815195bbd077d", size = 10567579, upload-time = "2026-01-05T02:27:29.382Z" }, + { url = "https://files.pythonhosted.org/packages/6d/d5/e7685c66b7f011c73cd746127f986358a26c642a4e4a1aa5ab51481b6586/netcdf4-1.7.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:31a2318305de6831a18df25ad0df9f03b6d68666af0356d4f6057d66c02ffeb6", size = 10255032, upload-time = "2026-01-05T02:27:31.744Z" }, + { url = "https://files.pythonhosted.org/packages/a6/14/7506738bb6c8bc373b01e5af8f3b727f83f4f496c6b108490ea2609dc2cf/netcdf4-1.7.4-cp314-cp314t-win_amd64.whl", hash = "sha256:6c4a0aa9446c3a616ef3be015b629dc6173643f8b09546de26a4e40e272cd1ed", size = 22289653, upload-time = "2026-01-05T02:27:34.294Z" }, + { url = "https://files.pythonhosted.org/packages/af/2e/39d5e9179c543f2e6e149a65908f83afd9b6d64379a90789b323111761db/netcdf4-1.7.4-cp314-cp314t-win_arm64.whl", hash = "sha256:034220887d48da032cb2db5958f69759dbb04eb33e279ec6390571d4aea734fe", size = 2531682, upload-time = "2026-01-05T02:27:37.062Z" }, +] + [[package]] name = "networkx" version = "3.5" @@ -1860,6 +2152,15 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/eb/8d/776adee7bbf76365fdd7f2552710282c79a4ead5d2a46408c9043a2b70ba/networkx-3.5-py3-none-any.whl", hash = "sha256:0030d386a9a06dee3565298b4a734b68589749a544acbb6c412dc9e2489ec6ec", size = 2034406, upload-time = "2025-05-29T11:35:04.961Z" }, ] +[[package]] +name = "nodify" +version = "0.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/b4/d1a3da7364b94ea658aa257a248e817296019273d99c3773eb88768162b9/nodify-0.0.12.tar.gz", hash = "sha256:0905e42279f5958ed76cc67ced1c5e1cbc6c3e3e88763b0c838f7b7e0fba828a", size = 6538789, upload-time = "2025-10-09T23:24:57.939Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0f/de/c682cbbd8886eda756364be9e4e156a9906711a7b535a6691346e2a69061/nodify-0.0.12-py3-none-any.whl", hash = "sha256:8fae737a644a300fea9b68d4e296375da6cfb74b75dff84ea17aa197888473e6", size = 6610258, upload-time = "2025-10-09T23:24:56.437Z" }, +] + [[package]] name = "numba" version = "0.61.2" @@ -2227,6 +2528,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, ] +[[package]] +name = "pathos" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "dill" }, + { name = "multiprocess" }, + { name = "pox" }, + { name = "ppft" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/99/7fcb91495e40735958a576b9bde930cc402d594e9ad5277bdc9b6326e1c8/pathos-0.3.2.tar.gz", hash = "sha256:4f2a42bc1e10ccf0fe71961e7145fc1437018b6b21bd93b2446abc3983e49a7a", size = 166506, upload-time = "2024-01-28T19:11:27.603Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f4/7f/cea34872c000d17972dad998575d14656d7c6bcf1a08a8d66d73c1ef2cca/pathos-0.3.2-py3-none-any.whl", hash = "sha256:d669275e6eb4b3fbcd2846d7a6d1bba315fe23add0c614445ba1408d8b38bafe", size = 82075, upload-time = "2024-01-28T19:11:25.56Z" }, +] + [[package]] name = "pexpect" version = "4.9.0" @@ -2329,6 +2645,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ed/20/f2b7ac96a91cc5f70d81320adad24cc41bf52013508d649b1481db225780/plotly-6.2.0-py3-none-any.whl", hash = "sha256:32c444d4c940887219cb80738317040363deefdfee4f354498cc0b6dab8978bd", size = 9635469, upload-time = "2025-06-26T16:20:40.76Z" }, ] +[[package]] +name = "pox" +version = "0.3.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/44/58/4385741dea1d74fe9dfed7ff42975266634ef8000f2c8e96717079c916b1/pox-0.3.7.tar.gz", hash = "sha256:0652f6f2103fe6d4ba638beb6fa8d3e8a68fd44bcb63315c614118515bcc3afb", size = 119442, upload-time = "2026-01-19T02:09:12.573Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/ac/4d5f104edf2aae2fec85567ec1d1969010de8124c5c45514f25e14900b65/pox-0.3.7-py3-none-any.whl", hash = "sha256:82a495249d13371314c1a5b5626a115e067ef5215d49530bf5efa37fbc25b56a", size = 29402, upload-time = "2026-01-19T02:09:11.024Z" }, +] + +[[package]] +name = "ppft" +version = "1.7.8" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/d2/281aa3466e948283d51b83238fb456f65e14f8ade5f8627822578cd2708f/ppft-1.7.8.tar.gz", hash = "sha256:5f696d4f397ae9b0af39b1faffb31957c51dfbc5a3815856472d4f4e872937ee", size = 136349, upload-time = "2026-01-19T03:03:13.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/e1/d1b380af6443e7c33aeb40617ebdc17c39dc30095235643cc518e3908203/ppft-1.7.8-py3-none-any.whl", hash = "sha256:d3e0e395215b14afc3dd5adfc032ccecfda2d4ed50dc7ded076cd1d215442843", size = 56759, upload-time = 
"2026-01-19T03:03:11.896Z" }, +] + [[package]] name = "prompt-toolkit" version = "3.0.51" @@ -2414,6 +2748,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cc/35/cc0aaecf278bb4575b8555f2b137de5ab821595ddae9da9d3cd1da4072c7/propcache-0.3.2-py3-none-any.whl", hash = "sha256:98f1ec44fb675f5052cccc8e609c46ed23a35a1cfd18545ad4e29002d858a43f", size = 12663, upload-time = "2025-06-09T22:56:04.484Z" }, ] +[[package]] +name = "protobuf" +version = "7.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/fd/5b1491d9e4b586d621c54f4c36b888714164b6875f8d6afa3f9072906a51/protobuf-7.35.0.tar.gz", hash = "sha256:a2efd84605f41e559f1881b0912b44099d0a2ac9bf46b3474823f10fb393b0e6", size = 458677, upload-time = "2026-05-19T23:02:29.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda", size = 433225, upload-time = "2026-05-19T23:02:19.884Z" }, + { url = "https://files.pythonhosted.org/packages/8b/39/1c76c2da93f3c507e958e0aecee2391cc44d4625de6c728bbc555195b5a8/protobuf-7.35.0-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:fcbe42a4ac09d3ec9c987ddfcd956afd0b15f1ff613bd8371bde9405ffd5c8e5", size = 328847, upload-time = "2026-05-19T23:02:22.3Z" }, + { url = "https://files.pythonhosted.org/packages/91/1a/39f7ce90a238c1a987a4d81ec26379e02ca0aff367de68e4a1fa474215b9/protobuf-7.35.0-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:4cbf5cc286130e06a6c9bbefac442431173906dfcc979712183d4adcc01b37ee", size = 344030, upload-time = "2026-05-19T23:02:23.591Z" }, + { url = "https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl", hash = 
"sha256:6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011", size = 327130, upload-time = "2026-05-19T23:02:24.637Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e5/e46adb0badc388bfb84877a5f9f026aff63f60e611016cf64dbe77e05446/protobuf-7.35.0-cp310-abi3-win32.whl", hash = "sha256:4c4617b83ade0e279d1d2bfe04025a1adb87f9ed657de038620dc0ff959357f6", size = 428946, upload-time = "2026-05-19T23:02:25.741Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ab/547fbd9e16d879dd13c167478f8ae0a83a428008ca07a5e06acdc23ad473/protobuf-7.35.0-cp310-abi3-win_amd64.whl", hash = "sha256:f05bcadf9a2a6b8dda047007075135fb7d08c73d9177aabc067e1be46881a201", size = 439996, upload-time = "2026-05-19T23:02:26.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/ef/50433d346c56657a70d27f156c7b349ac59a068b01de4eb796e747eecc43/protobuf-7.35.0-py3-none-any.whl", hash = "sha256:c13f325cf242bad135c350629eeb5d54b24228eb472fb3e2e9ebbd4c5dc20ca0", size = 171659, upload-time = "2026-05-19T23:02:27.842Z" }, +] + [[package]] name = "psutil" version = "7.0.0" @@ -2943,6 +3292,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/e4/56027c4a6b4ae70ca9de302488c5ca95ad4a39e190093d6c1a8ace08341b/requests-2.32.4-py3-none-any.whl", hash = "sha256:27babd3cda2a6d50b30443204ee89830707d396671944c998b5975b031ac2b2c", size = 64847, upload-time = "2025-06-09T16:43:05.728Z" }, ] +[[package]] +name = "rich" +version = "15.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown-it-py" }, + { name = "pygments" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, +] + [[package]] name = "rpds-py" version = "0.30.0" @@ -3110,6 +3472,73 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] +[[package]] +name = "scikit-image" +version = "0.26.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "imageio" }, + { name = "lazy-loader" }, + { name = "networkx" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "pillow" }, + { name = "scipy" }, + { name = "tifffile", version = "2026.3.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.12'" }, + { name = "tifffile", version = "2026.5.15", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/b4/2528bb43c67d48053a7a649a9666432dc307d66ba02e3a6d5c40f46655df/scikit_image-0.26.0.tar.gz", hash = "sha256:f5f970ab04efad85c24714321fcc91613fcb64ef2a892a13167df2f3e59199fa", size = 22729739, upload-time = "2025-12-20T17:12:21.824Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/76/16/8a407688b607f86f81f8c649bf0d68a2a6d67375f18c2d660aba20f5b648/scikit_image-0.26.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b1ede33a0fb3731457eaf53af6361e73dd510f449dac437ab54573b26788baf0", size = 12355510, upload-time = "2025-12-20T17:10:31.628Z" }, + { url = 
"https://files.pythonhosted.org/packages/6b/f9/7efc088ececb6f6868fd4475e16cfafc11f242ce9ab5fc3557d78b5da0d4/scikit_image-0.26.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7af7aa331c6846bd03fa28b164c18d0c3fd419dbb888fb05e958ac4257a78fdd", size = 12056334, upload-time = "2025-12-20T17:10:34.559Z" }, + { url = "https://files.pythonhosted.org/packages/9f/1e/bc7fb91fb5ff65ef42346c8b7ee8b09b04eabf89235ab7dbfdfd96cbd1ea/scikit_image-0.26.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ea6207d9e9d21c3f464efe733121c0504e494dbdc7728649ff3e23c3c5a4953", size = 13297768, upload-time = "2025-12-20T17:10:37.733Z" }, + { url = "https://files.pythonhosted.org/packages/a5/2a/e71c1a7d90e70da67b88ccc609bd6ae54798d5847369b15d3a8052232f9d/scikit_image-0.26.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:74aa5518ccea28121f57a95374581d3b979839adc25bb03f289b1bc9b99c58af", size = 13711217, upload-time = "2025-12-20T17:10:40.935Z" }, + { url = "https://files.pythonhosted.org/packages/d4/59/9637ee12c23726266b91296791465218973ce1ad3e4c56fc81e4d8e7d6e1/scikit_image-0.26.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d5c244656de905e195a904e36dbc18585e06ecf67d90f0482cbde63d7f9ad59d", size = 14337782, upload-time = "2025-12-20T17:10:43.452Z" }, + { url = "https://files.pythonhosted.org/packages/e7/5c/a3e1e0860f9294663f540c117e4bf83d55e5b47c281d475cc06227e88411/scikit_image-0.26.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:21a818ee6ca2f2131b9e04d8eb7637b5c18773ebe7b399ad23dcc5afaa226d2d", size = 14805997, upload-time = "2025-12-20T17:10:45.93Z" }, + { url = "https://files.pythonhosted.org/packages/d3/c6/2eeacf173da041a9e388975f54e5c49df750757fcfc3ee293cdbbae1ea0a/scikit_image-0.26.0-cp311-cp311-win_amd64.whl", hash = "sha256:9490360c8d3f9a7e85c8de87daf7c0c66507960cf4947bb9610d1751928721c7", size = 11878486, upload-time = "2025-12-20T17:10:48.246Z" }, + { url = 
"https://files.pythonhosted.org/packages/c3/a4/a852c4949b9058d585e762a66bf7e9a2cd3be4795cd940413dfbfbb0ce79/scikit_image-0.26.0-cp311-cp311-win_arm64.whl", hash = "sha256:0baa0108d2d027f34d748e84e592b78acc23e965a5de0e4bb03cf371de5c0581", size = 11346518, upload-time = "2025-12-20T17:10:50.575Z" }, + { url = "https://files.pythonhosted.org/packages/99/e8/e13757982264b33a1621628f86b587e9a73a13f5256dad49b19ba7dc9083/scikit_image-0.26.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d454b93a6fa770ac5ae2d33570f8e7a321bb80d29511ce4b6b78058ebe176e8c", size = 12376452, upload-time = "2025-12-20T17:10:52.796Z" }, + { url = "https://files.pythonhosted.org/packages/e3/be/f8dd17d0510f9911f9f17ba301f7455328bf13dae416560126d428de9568/scikit_image-0.26.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3409e89d66eff5734cd2b672d1c48d2759360057e714e1d92a11df82c87cba37", size = 12061567, upload-time = "2025-12-20T17:10:55.207Z" }, + { url = "https://files.pythonhosted.org/packages/b3/2b/c70120a6880579fb42b91567ad79feb4772f7be72e8d52fec403a3dde0c6/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c717490cec9e276afb0438dd165b7c3072d6c416709cc0f9f5a4c1070d23a44", size = 13084214, upload-time = "2025-12-20T17:10:57.468Z" }, + { url = "https://files.pythonhosted.org/packages/f4/a2/70401a107d6d7466d64b466927e6b96fcefa99d57494b972608e2f8be50f/scikit_image-0.26.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7df650e79031634ac90b11e64a9eedaf5a5e06fcd09bcd03a34be01745744466", size = 13561683, upload-time = "2025-12-20T17:10:59.49Z" }, + { url = "https://files.pythonhosted.org/packages/13/a5/48bdfd92794c5002d664e0910a349d0a1504671ef5ad358150f21643c79a/scikit_image-0.26.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:cefd85033e66d4ea35b525bb0937d7f42d4cdcfed2d1888e1570d5ce450d3932", size = 14112147, upload-time = "2025-12-20T17:11:02.083Z" }, + { url = 
"https://files.pythonhosted.org/packages/ee/b5/ac71694da92f5def5953ca99f18a10fe98eac2dd0a34079389b70b4d0394/scikit_image-0.26.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3f5bf622d7c0435884e1e141ebbe4b2804e16b2dd23ae4c6183e2ea99233be70", size = 14661625, upload-time = "2025-12-20T17:11:04.528Z" }, + { url = "https://files.pythonhosted.org/packages/23/4d/a3cc1e96f080e253dad2251bfae7587cf2b7912bcd76fd43fd366ff35a87/scikit_image-0.26.0-cp312-cp312-win_amd64.whl", hash = "sha256:abed017474593cd3056ae0fe948d07d0747b27a085e92df5474f4955dd65aec0", size = 11911059, upload-time = "2025-12-20T17:11:06.61Z" }, + { url = "https://files.pythonhosted.org/packages/35/8a/d1b8055f584acc937478abf4550d122936f420352422a1a625eef2c605d8/scikit_image-0.26.0-cp312-cp312-win_arm64.whl", hash = "sha256:4d57e39ef67a95d26860c8caf9b14b8fb130f83b34c6656a77f191fa6d1d04d8", size = 11348740, upload-time = "2025-12-20T17:11:09.118Z" }, + { url = "https://files.pythonhosted.org/packages/4f/48/02357ffb2cca35640f33f2cfe054a4d6d5d7a229b88880a64f1e45c11f4e/scikit_image-0.26.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a2e852eccf41d2d322b8e60144e124802873a92b8d43a6f96331aa42888491c7", size = 12346329, upload-time = "2025-12-20T17:11:11.599Z" }, + { url = "https://files.pythonhosted.org/packages/67/b9/b792c577cea2c1e94cda83b135a656924fc57c428e8a6d302cd69aac1b60/scikit_image-0.26.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:98329aab3bc87db352b9887f64ce8cdb8e75f7c2daa19927f2e121b797b678d5", size = 12031726, upload-time = "2025-12-20T17:11:13.871Z" }, + { url = "https://files.pythonhosted.org/packages/07/a9/9564250dfd65cb20404a611016db52afc6268b2b371cd19c7538ea47580f/scikit_image-0.26.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:915bb3ba66455cf8adac00dc8fdf18a4cd29656aec7ddd38cb4dda90289a6f21", size = 13094910, upload-time = "2025-12-20T17:11:16.2Z" }, + { url = 
"https://files.pythonhosted.org/packages/a3/b8/0d8eeb5a9fd7d34ba84f8a55753a0a3e2b5b51b2a5a0ade648a8db4a62f7/scikit_image-0.26.0-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b36ab5e778bf50af5ff386c3ac508027dc3aaeccf2161bdf96bde6848f44d21b", size = 13660939, upload-time = "2025-12-20T17:11:18.464Z" }, + { url = "https://files.pythonhosted.org/packages/2f/d6/91d8973584d4793d4c1a847d388e34ef1218d835eeddecfc9108d735b467/scikit_image-0.26.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:09bad6a5d5949c7896c8347424c4cca899f1d11668030e5548813ab9c2865dcb", size = 14138938, upload-time = "2025-12-20T17:11:20.919Z" }, + { url = "https://files.pythonhosted.org/packages/39/9a/7e15d8dc10d6bbf212195fb39bdeb7f226c46dd53f9c63c312e111e2e175/scikit_image-0.26.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:aeb14db1ed09ad4bee4ceb9e635547a8d5f3549be67fc6c768c7f923e027e6cd", size = 14752243, upload-time = "2025-12-20T17:11:23.347Z" }, + { url = "https://files.pythonhosted.org/packages/8f/58/2b11b933097bc427e42b4a8b15f7de8f24f2bac1fd2779d2aea1431b2c31/scikit_image-0.26.0-cp313-cp313-win_amd64.whl", hash = "sha256:ac529eb9dbd5954f9aaa2e3fe9a3fd9661bfe24e134c688587d811a0233127f1", size = 11906770, upload-time = "2025-12-20T17:11:25.297Z" }, + { url = "https://files.pythonhosted.org/packages/ad/ec/96941474a18a04b69b6f6562a5bd79bd68049fa3728d3b350976eccb8b93/scikit_image-0.26.0-cp313-cp313-win_arm64.whl", hash = "sha256:a2d211bc355f59725efdcae699b93b30348a19416cc9e017f7b2fb599faf7219", size = 11342506, upload-time = "2025-12-20T17:11:27.399Z" }, + { url = "https://files.pythonhosted.org/packages/03/e5/c1a9962b0cf1952f42d32b4a2e48eed520320dbc4d2ff0b981c6fa508b6b/scikit_image-0.26.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:9eefb4adad066da408a7601c4c24b07af3b472d90e08c3e7483d4e9e829d8c49", size = 12663278, upload-time = "2025-12-20T17:11:29.358Z" }, + { url = 
"https://files.pythonhosted.org/packages/ae/97/c1a276a59ce8e4e24482d65c1a3940d69c6b3873279193b7ebd04e5ee56b/scikit_image-0.26.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:6caec76e16c970c528d15d1c757363334d5cb3069f9cea93d2bead31820511f3", size = 12405142, upload-time = "2025-12-20T17:11:31.282Z" }, + { url = "https://files.pythonhosted.org/packages/d4/4a/f1cbd1357caef6c7993f7efd514d6e53d8fd6f7fe01c4714d51614c53289/scikit_image-0.26.0-cp313-cp313t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a07200fe09b9d99fcdab959859fe0f7db8df6333d6204344425d476850ce3604", size = 12942086, upload-time = "2025-12-20T17:11:33.683Z" }, + { url = "https://files.pythonhosted.org/packages/5b/6f/74d9fb87c5655bd64cf00b0c44dc3d6206d9002e5f6ba1c9aeb13236f6bf/scikit_image-0.26.0-cp313-cp313t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92242351bccf391fc5df2d1529d15470019496d2498d615beb68da85fe7fdf37", size = 13265667, upload-time = "2025-12-20T17:11:36.11Z" }, + { url = "https://files.pythonhosted.org/packages/a7/73/faddc2413ae98d863f6fa2e3e14da4467dd38e788e1c23346cf1a2b06b97/scikit_image-0.26.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:52c496f75a7e45844d951557f13c08c81487c6a1da2e3c9c8a39fcde958e02cc", size = 14001966, upload-time = "2025-12-20T17:11:38.55Z" }, + { url = "https://files.pythonhosted.org/packages/02/94/9f46966fa042b5d57c8cd641045372b4e0df0047dd400e77ea9952674110/scikit_image-0.26.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:20ef4a155e2e78b8ab973998e04d8a361d49d719e65412405f4dadd9155a61d9", size = 14359526, upload-time = "2025-12-20T17:11:41.087Z" }, + { url = "https://files.pythonhosted.org/packages/5d/b4/2840fe38f10057f40b1c9f8fb98a187a370936bf144a4ac23452c5ef1baf/scikit_image-0.26.0-cp313-cp313t-win_amd64.whl", hash = "sha256:c9087cf7d0e7f33ab5c46d2068d86d785e70b05400a891f73a13400f1e1faf6a", size = 12287629, upload-time = "2025-12-20T17:11:43.11Z" }, + { url = 
"https://files.pythonhosted.org/packages/22/ba/73b6ca70796e71f83ab222690e35a79612f0117e5aaf167151b7d46f5f2c/scikit_image-0.26.0-cp313-cp313t-win_arm64.whl", hash = "sha256:27d58bc8b2acd351f972c6508c1b557cfed80299826080a4d803dd29c51b707e", size = 11647755, upload-time = "2025-12-20T17:11:45.279Z" }, + { url = "https://files.pythonhosted.org/packages/51/44/6b744f92b37ae2833fd423cce8f806d2368859ec325a699dc30389e090b9/scikit_image-0.26.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:63af3d3a26125f796f01052052f86806da5b5e54c6abef152edb752683075a9c", size = 12365810, upload-time = "2025-12-20T17:11:47.357Z" }, + { url = "https://files.pythonhosted.org/packages/40/f5/83590d9355191f86ac663420fec741b82cc547a4afe7c4c1d986bf46e4db/scikit_image-0.26.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:ce00600cd70d4562ed59f80523e18cdcc1fae0e10676498a01f73c255774aefd", size = 12075717, upload-time = "2025-12-20T17:11:49.483Z" }, + { url = "https://files.pythonhosted.org/packages/72/48/253e7cf5aee6190459fe136c614e2cbccc562deceb4af96e0863f1b8ee29/scikit_image-0.26.0-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6381edf972b32e4f54085449afde64365a57316637496c1325a736987083e2ab", size = 13161520, upload-time = "2025-12-20T17:11:51.58Z" }, + { url = "https://files.pythonhosted.org/packages/73/c3/cec6a3cbaadfdcc02bd6ff02f3abfe09eaa7f4d4e0a525a1e3a3f4bce49c/scikit_image-0.26.0-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c6624a76c6085218248154cc7e1500e6b488edcd9499004dd0d35040607d7505", size = 13684340, upload-time = "2025-12-20T17:11:53.708Z" }, + { url = "https://files.pythonhosted.org/packages/d4/0d/39a776f675d24164b3a267aa0db9f677a4cb20127660d8bf4fd7fef66817/scikit_image-0.26.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:f775f0e420faac9c2aa6757135f4eb468fb7b70e0b67fa77a5e79be3c30ee331", size = 14203839, upload-time = "2025-12-20T17:11:55.89Z" }, + { url = 
"https://files.pythonhosted.org/packages/ee/25/2514df226bbcedfe9b2caafa1ba7bc87231a0c339066981b182b08340e06/scikit_image-0.26.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ede4d6d255cc5da9faeb2f9ba7fedbc990abbc652db429f40a16b22e770bb578", size = 14770021, upload-time = "2025-12-20T17:11:58.014Z" }, + { url = "https://files.pythonhosted.org/packages/8d/5b/0671dc91c0c79340c3fe202f0549c7d3681eb7640fe34ab68a5f090a7c7f/scikit_image-0.26.0-cp314-cp314-win_amd64.whl", hash = "sha256:0660b83968c15293fd9135e8d860053ee19500d52bf55ca4fb09de595a1af650", size = 12023490, upload-time = "2025-12-20T17:12:00.013Z" }, + { url = "https://files.pythonhosted.org/packages/65/08/7c4cb59f91721f3de07719085212a0b3962e3e3f2d1818cbac4eeb1ea53e/scikit_image-0.26.0-cp314-cp314-win_arm64.whl", hash = "sha256:b8d14d3181c21c11170477a42542c1addc7072a90b986675a71266ad17abc37f", size = 11473782, upload-time = "2025-12-20T17:12:01.983Z" }, + { url = "https://files.pythonhosted.org/packages/49/41/65c4258137acef3d73cb561ac55512eacd7b30bb4f4a11474cad526bc5db/scikit_image-0.26.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:cde0bbd57e6795eba83cb10f71a677f7239271121dc950bc060482834a668ad1", size = 12686060, upload-time = "2025-12-20T17:12:03.886Z" }, + { url = "https://files.pythonhosted.org/packages/e7/32/76971f8727b87f1420a962406388a50e26667c31756126444baf6668f559/scikit_image-0.26.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:163e9afb5b879562b9aeda0dd45208a35316f26cc7a3aed54fd601604e5cf46f", size = 12422628, upload-time = "2025-12-20T17:12:05.921Z" }, + { url = "https://files.pythonhosted.org/packages/37/0d/996febd39f757c40ee7b01cdb861867327e5c8e5f595a634e8201462d958/scikit_image-0.26.0-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:724f79fd9b6cb6f4a37864fe09f81f9f5d5b9646b6868109e1b100d1a7019e59", size = 12962369, upload-time = "2025-12-20T17:12:07.912Z" }, + { url = 
"https://files.pythonhosted.org/packages/48/b4/612d354f946c9600e7dea012723c11d47e8d455384e530f6daaaeb9bf62c/scikit_image-0.26.0-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3268f13310e6857508bd87202620df996199a016a1d281b309441d227c822394", size = 13272431, upload-time = "2025-12-20T17:12:10.255Z" }, + { url = "https://files.pythonhosted.org/packages/0a/6e/26c00b466e06055a086de2c6e2145fe189ccdc9a1d11ccc7de020f2591ad/scikit_image-0.26.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fac96a1f9b06cd771cbbb3cd96c5332f36d4efd839b1d8b053f79e5887acde62", size = 14016362, upload-time = "2025-12-20T17:12:12.793Z" }, + { url = "https://files.pythonhosted.org/packages/47/88/00a90402e1775634043c2a0af8a3c76ad450866d9fa444efcc43b553ba2d/scikit_image-0.26.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:2c1e7bd342f43e7a97e571b3f03ba4c1293ea1a35c3f13f41efdc8a81c1dc8f2", size = 14364151, upload-time = "2025-12-20T17:12:14.909Z" }, + { url = "https://files.pythonhosted.org/packages/da/ca/918d8d306bd43beacff3b835c6d96fac0ae64c0857092f068b88db531a7c/scikit_image-0.26.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b702c3bb115e1dcf4abf5297429b5c90f2189655888cbed14921f3d26f81d3a4", size = 12413484, upload-time = "2025-12-20T17:12:17.046Z" }, + { url = "https://files.pythonhosted.org/packages/dc/cd/4da01329b5a8d47ff7ec3c99a2b02465a8017b186027590dc7425cee0b56/scikit_image-0.26.0-cp314-cp314t-win_arm64.whl", hash = "sha256:0608aa4a9ec39e0843de10d60edb2785a30c1c47819b67866dd223ebd149acaf", size = 11769501, upload-time = "2025-12-20T17:12:19.339Z" }, +] + [[package]] name = "scikit-learn" version = "1.7.0" @@ -3199,6 +3628,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/49/65/dea992c6a97074f6d8ff9eab34741298cac2ce23e2b6c74fb7d08afdf85c/sentinels-1.1.1-py3-none-any.whl", hash = "sha256:835d3b28f3b47f5284afa4bf2db6e00f2dc5f80f9923d4b7e7aeeeccf6146a11", size = 3744, upload-time = "2025-08-12T07:57:48.858Z" }, ] +[[package]] +name 
= "sentry-sdk" +version = "2.60.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/a2/2e6c090db384cc515069f4f85542bd5baf6786852073020ea73d4a76d3ea/sentry_sdk-2.60.0.tar.gz", hash = "sha256:0bd25e54e78ca02d0be512529fa644bbbf9e8470d7b26371294012d4ca93c978", size = 452946, upload-time = "2026-05-13T13:34:52.516Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/41/f2b800b7f12a05dd48c2a6280d4dd812d1425fc66ed3fe3fd99420c41d1a/sentry_sdk-2.60.0-py3-none-any.whl", hash = "sha256:28a536c03291c8bcb363cf35c611b32738ec118ff64d8d6383b096448ac4c803", size = 475616, upload-time = "2026-05-13T13:34:50.259Z" }, +] + [[package]] name = "setuptools" version = "80.9.0" @@ -3208,6 +3650,60 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + +[[package]] +name = "sisl" +version = "0.16.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { 
name = "pyparsing" }, + { name = "scipy" }, + { name = "xarray" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/73/8a/ce69ddd9495b8cd52a99eb631a3176a5818fd5bfcbfde941c9efe1a5c876/sisl-0.16.4.tar.gz", hash = "sha256:bba5fd45a6286d20eabd1232ea83d830d63f343c6212021034c31d53dee928a3", size = 3177153, upload-time = "2026-03-19T08:50:53.972Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/73/7a/9007c5afa91664b5f345f02568deec141c3c8a6e2cfeae2eefc4d3d88d66/sisl-0.16.4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:cb7f48cb60b3debd53395485066048ff081d182fdf29d8697c62e95feba0df28", size = 4891948, upload-time = "2026-03-19T08:50:24.857Z" }, + { url = "https://files.pythonhosted.org/packages/1b/a4/bb196b01aa330c04566cc299e556d07440e4d781ddf0080c3c09e4da9994/sisl-0.16.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c112e5dd7a0d6736a1b851fb5c1f703dee6ca1e76790c09f2858b56cc1f3808f", size = 5680212, upload-time = "2026-03-19T08:50:26.239Z" }, + { url = "https://files.pythonhosted.org/packages/0c/3c/76c8dca17a7298867c05ee3bf787f8c8e90dea04b990ebd5abbfef533094/sisl-0.16.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b29fee46cc73f02c79f7419cf4b998ef5cddd904a28550384fa2ed2c991fd3ca", size = 6091340, upload-time = "2026-03-19T08:50:27.826Z" }, + { url = "https://files.pythonhosted.org/packages/9c/2b/a1d6f7f540f409675a3be73932a7f711fc2a29c5d01dde1f38a17cce7b90/sisl-0.16.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:777533765f992f1cf2b1bd391cba2358826f1689ce3ca0fd93ebfb493b17491e", size = 4744733, upload-time = "2026-03-19T08:50:29.527Z" }, + { url = "https://files.pythonhosted.org/packages/7e/61/ec019ead6f34c26a586999b3a565548acce170fa4acb711026f30c42831b/sisl-0.16.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a99fcd7af36162e24b9acf33f39c01c192a5910a2504c7bb8991cce87b182d50", size = 5555109, upload-time = "2026-03-19T08:50:31.194Z" }, + { url = 
"https://files.pythonhosted.org/packages/57/fb/9683d84d0fe7f0dc83a8d1d42e81e9b117f006ab4e140b42cdbe7578cf3b/sisl-0.16.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:94bb737f35ed9a64aaaa1fbec45f42e32ac884980c57552efcca9c3ef6534e39", size = 5986217, upload-time = "2026-03-19T08:50:32.644Z" }, + { url = "https://files.pythonhosted.org/packages/8a/67/acb7224f88ad16686c9cb58122063389a200298b2ee74f2dfa218ed76ce7/sisl-0.16.4-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:b2b259e6c65446446004afe8a2704f71a88f25333a26e8bec2a272b0790e0dca", size = 4871758, upload-time = "2026-03-19T08:50:34.494Z" }, + { url = "https://files.pythonhosted.org/packages/19/0a/17c235535c6ee253c4e8a25c2995a45a0aea28b81fbc26be402a15a0ce6e/sisl-0.16.4-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:423cf217a3d139d9d6000a694851f9eb5b16a92dbd33cff6490b7f6a90ffe796", size = 5546346, upload-time = "2026-03-19T08:50:35.961Z" }, + { url = "https://files.pythonhosted.org/packages/1c/68/6d908b2590ad0f925a5724ec4827c42b36e72cf2ad733de851d054d98937/sisl-0.16.4-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:464264f96f39c186e8a71b054983884dae94d18c92725a69fb8caecbbb831acd", size = 5973087, upload-time = "2026-03-19T08:50:37.637Z" }, + { url = "https://files.pythonhosted.org/packages/8a/4e/f03ca37ad48ef969b564444bf0e364d1362ee4fb037350ff5777baa62a2e/sisl-0.16.4-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:8bcded1e5b015155d03d756aa87f4748c049f37c2b9913e7ea3006b3b931539a", size = 4541276, upload-time = "2026-03-19T08:50:39.313Z" }, + { url = "https://files.pythonhosted.org/packages/48/9b/967ea173e01c4700430147f5f325b64720f8ab334483db691fbc0fa7a2a8/sisl-0.16.4-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0bb18d3cf141eb75559c99ea3613563a4041757630684375b59a8cb0510ca8ab", size = 5299450, upload-time = "2026-03-19T08:50:40.833Z" }, + { url = 
"https://files.pythonhosted.org/packages/78/3c/c11255084e02f2100702247eaf8ab3a92ea8a6e5b64a4aa5a792322b6809/sisl-0.16.4-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:b5d9b5b44577cb14e82a1321dd121498e80d1350bcbbadd700eb68e8a9ca2f41", size = 5708911, upload-time = "2026-03-19T08:50:42.613Z" }, + { url = "https://files.pythonhosted.org/packages/cc/c4/c649f133a60379950a61b09c2ebbc4406b6eeb4c9ae8243bafec5ef7f1c5/sisl-0.16.4-cp314-cp314-macosx_14_0_arm64.whl", hash = "sha256:1a5fd37ada238b5e84e0c03de14e3de8983e0cdf0a67de5b5d9cbfb5c3000c32", size = 4900289, upload-time = "2026-03-19T08:50:44.137Z" }, + { url = "https://files.pythonhosted.org/packages/4b/a8/3807ebe875a7422eedcfab00878fac4f59c7723dd82878dfd706d7ad6ba5/sisl-0.16.4-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:70cc5b25a50a9709fcda9d1001c2a451c24c06aba282dccb2f35756f79025395", size = 5570979, upload-time = "2026-03-19T08:50:45.602Z" }, + { url = "https://files.pythonhosted.org/packages/a9/71/35d856ee9285baab216b20113d9f24ca5a77d0dfbb07195d41e3ebc71d53/sisl-0.16.4-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1d9115c6444ad628f575f190eb15a1bf0e8ff248d48d3ba8da31bfa299f648ec", size = 5980274, upload-time = "2026-03-19T08:50:47.57Z" }, + { url = "https://files.pythonhosted.org/packages/6a/73/dce43b5920137836fa0f428456f357c6282307f23e2aa2022d111bde7e86/sisl-0.16.4-cp314-cp314t-macosx_14_0_arm64.whl", hash = "sha256:676d53a450f2133ebf9342b6c694c9e4cdebb7f6086d7e09e09001e02560b89e", size = 4561158, upload-time = "2026-03-19T08:50:48.953Z" }, + { url = "https://files.pythonhosted.org/packages/1d/9a/8ca7cc4f11641d23bcc8febd73a150c0a48da1dfbaaa1726ea853c9b6dde/sisl-0.16.4-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ea22bbc84c2d2416ddc503578a271523ef00e978ae9fce0092fad37549910c9", size = 5302034, upload-time = "2026-03-19T08:50:50.409Z" }, + { url = 
"https://files.pythonhosted.org/packages/be/9f/4644f89d1121b98c0fd478ad2c7391faab1c1f6a616157e5379b099baddc/sisl-0.16.4-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:44206fe83252785b43e84800ffe36cc7eb536ab9f1d7e8b537405168703f7469", size = 5714058, upload-time = "2026-03-19T08:50:52.193Z" }, +] + +[package.optional-dependencies] +viz = [ + { name = "ase" }, + { name = "dill" }, + { name = "matplotlib" }, + { name = "netcdf4", version = "1.7.3", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine == 'ARM64' and sys_platform == 'win32'" }, + { name = "netcdf4", version = "1.7.4", source = { registry = "https://pypi.org/simple" }, marker = "platform_machine != 'ARM64' or sys_platform != 'win32'" }, + { name = "nodify" }, + { name = "pathos" }, + { name = "plotly" }, + { name = "scikit-image" }, +] + [[package]] name = "six" version = "1.17.0" @@ -3217,6 +3713,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = "sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" }, +] + [[package]] name = "spglib" version = "2.6.0" @@ -3315,6 +3820,46 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638, upload-time = "2025-03-13T13:49:21.846Z" }, ] +[[package]] +name = "tifffile" +version = "2026.3.3" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.12' and platform_machine == 'ARM64' and sys_platform == 'win32'", + "python_full_version < '3.12' and platform_machine != 'ARM64' and sys_platform == 'win32'", + "python_full_version < '3.12' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", marker = "python_full_version < '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/cb/2f6d79c7576e22c116352a801f4c3c8ace5957e9aced862012430b62e14f/tifffile-2026.3.3.tar.gz", hash = "sha256:d9a1266bed6f2ee1dd0abde2018a38b4f8b2935cb843df381d70ac4eac5458b7", size = 388745, upload-time = "2026-03-03T19:14:38.134Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/e4/e804505f87627cd8cdae9c010c47c4485fd8c1ce31a7dd0ab7fcc4707377/tifffile-2026.3.3-py3-none-any.whl", hash = "sha256:e8be15c94273113d31ecb7aa3a39822189dd11c4967e3cc88c178f1ad2fd1170", size = 243960, upload-time = "2026-03-03T19:14:35.808Z" }, +] + +[[package]] +name = "tifffile" +version = "2026.5.15" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_machine == 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and python_full_version 
< '3.14' and platform_machine != 'ARM64' and platform_python_implementation != 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine == 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.12' and platform_machine != 'ARM64' and platform_python_implementation == 'PyPy' and sys_platform == 'win32'", + "python_full_version >= '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", + "python_full_version >= '3.12' and python_full_version < '3.14' and platform_python_implementation != 'PyPy' and sys_platform != 'win32'", + "python_full_version >= '3.12' and platform_python_implementation == 'PyPy' and sys_platform != 'win32'", +] +dependencies = [ + { name = "numpy", marker = "python_full_version >= '3.12'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/40/66/0aef917d525767a40edebe088f8ed6a4417e6eb489c58f6805bfa872636b/tifffile-2026.5.15.tar.gz", hash = "sha256:ee4f3e07ee0d8ff4745a8c735ac2b72caa3173c7d6059b00fdc3ff492a0b635b", size = 429998, upload-time = "2026-05-15T20:04:55.896Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1f/6e/7d8850ff112f8f80d394ca45e89b975a3a43559d47af3137b767669b3294/tifffile-2026.5.15-py3-none-any.whl", hash = "sha256:6715515a53cabc0cefc5c9f13a0ae2c250e63e2ca784ce02d0b6c333810c2a17", size = 266665, upload-time = "2026-05-15T20:04:54.227Z" }, +] + [[package]] name = "torch" version = "2.7.1" @@ -3416,6 +3961,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/71/bd20ffcb7a64c753dc2463489a61bf69d531f308e390ad06390268c4ea04/triton-3.3.1-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3198adb9d78b77818a5388bff89fa72ff36f9da0bc689db2f0a651a67ce6a42", size = 155735832, upload-time = "2025-05-29T23:40:10.522Z" }, ] +[[package]] +name = "typer" +version = "0.25.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = 
"annotated-doc" }, + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e4/51/9aed62104cea109b820bbd6c14245af756112017d309da813ef107d42e7e/typer-0.25.1.tar.gz", hash = "sha256:9616eb8853a09ffeabab1698952f33c6f29ffdbceb4eaeecf571880e8d7664cc", size = 122276, upload-time = "2026-04-30T19:32:16.964Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3f/f9/2b3ff4e56e5fa7debfaf9eb135d0da96f3e9a1d5b27222223c7296336e5f/typer-0.25.1-py3-none-any.whl", hash = "sha256:75caa44ed46a03fb2dab8808753ffacdbfea88495e74c85a28c5eefcf5f39c89", size = 58409, upload-time = "2026-04-30T19:32:18.271Z" }, +] + [[package]] name = "typing-extensions" version = "4.15.0" @@ -3464,6 +4024,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a7/c2/fe1e52489ae3122415c51f387e221dd0773709bad6c6cdaa599e8a2c5185/urllib3-2.5.0-py3-none-any.whl", hash = "sha256:e6b01673c0fa6a13e374b50871808eb3bf7046c4b125b216f6bf1cc604cff0dc", size = 129795, upload-time = "2025-06-18T14:07:40.39Z" }, ] +[[package]] +name = "wandb" +version = "0.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/31/fe53d06b75ef0a7f2f0ee5931a89f7aedc27d233840b1839616860fed256/wandb-0.27.0.tar.gz", hash = "sha256:579e75300173059f9334e1f513a79ef15f6d9ea5c74e20d695633648cdd02031", size = 41090732, upload-time = "2026-05-14T03:44:08.894Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/5e/2c199e70e636ecfd217cde0bc7469f4511e1d03d0685eb92bfdfce391430/wandb-0.27.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:c156be4851485f3c4160cb6eb2e8991b4cdeffbccefc5636d33cf5e254847365", 
size = 24886476, upload-time = "2026-05-14T03:43:27.569Z" }, + { url = "https://files.pythonhosted.org/packages/0b/cd/a617c871cd304a9804e56a7ec2ec2c65685bf0091a2b9f91910175a149e2/wandb-0.27.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:20179f38afb0158859a4141d29ac650d3fdbd0cf801a74ce25565c934f03776c", size = 26045779, upload-time = "2026-05-14T03:43:31.999Z" }, + { url = "https://files.pythonhosted.org/packages/10/0a/d3f159a201530b84b72ca5f98c68d1f351c2d9a1864558ed76c811407fae/wandb-0.27.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:626497d7975fa898d0a4a239da7a510483495ca3514510dbe75004a25963af4d", size = 25480764, upload-time = "2026-05-14T03:43:35.922Z" }, + { url = "https://files.pythonhosted.org/packages/5f/6a/8721fcdf71d42639191040a77a585d2982402b1754700cb2ecfc2ca1470a/wandb-0.27.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:f772da7005cc26a2a32b729a16982a583dc68b3d493df6a09d0aa5c5ca5a2060", size = 27256204, upload-time = "2026-05-14T03:43:39.765Z" }, + { url = "https://files.pythonhosted.org/packages/00/5e/279d167ba79fb7a8a43401c9f25efd0f6663ee9bd1eaf5a8578530198888/wandb-0.27.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:63acfc5b994e4a90e4a2fbdee6d45e664da3dd865bb1419942c8995c06c41cf1", size = 25647469, upload-time = "2026-05-14T03:43:44.817Z" }, + { url = "https://files.pythonhosted.org/packages/94/51/a69ac59300e3c813939d0764348959ed2a21e14c668cb1cebcb04010da6a/wandb-0.27.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:17aae6e4a88cd05c00ea8f546220918e3ebb6f8c1c36b70ef04a5ac75f0d7160", size = 27599005, upload-time = "2026-05-14T03:43:50.926Z" }, + { url = "https://files.pythonhosted.org/packages/5f/40/bf510c8758727df020f83b717ebc1fcc1739ed7f6ae1796ebef60bf6f592/wandb-0.27.0-py3-none-win32.whl", hash = "sha256:0bd5659417e386bf6538b5e2ffe6885774c6197f0e4853bfed517d5b0db457f1", size = 25036164, upload-time = "2026-05-14T03:43:54.839Z" }, + { url = 
"https://files.pythonhosted.org/packages/54/ff/69f88e7d90c22b79bcb911143c13e59742ee192080b21015ff83a5a1f60a/wandb-0.27.0-py3-none-win_amd64.whl", hash = "sha256:89d584b73166eecee96fb446f18d0e45b1aa45aba6a3696296f3f06d7454516b", size = 25036170, upload-time = "2026-05-14T03:43:59.227Z" }, + { url = "https://files.pythonhosted.org/packages/f6/38/f7efd7a87297a55c7e9a331a1dbb5b19e54aeacc11fe6f43f8636a73987c/wandb-0.27.0-py3-none-win_arm64.whl", hash = "sha256:a6c129c311edf210a2b4f2f4acc557eff522628125f5f28ed27df19c16c07079", size = 22972710, upload-time = "2026-05-14T03:44:03.275Z" }, +] + [[package]] name = "wcwidth" version = "0.2.13" @@ -3494,6 +4083,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083, upload-time = "2024-12-07T15:28:26.465Z" }, ] +[[package]] +name = "xarray" +version = "2026.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4b/a6/6fe936a798a3a38a79c7422d1a31afd2e9a14690fcb0ccff96bc01f04bf2/xarray-2026.4.0.tar.gz", hash = "sha256:c4ac9a01a945d90d5b1628e2af045099a9d4943536d4f2ee3ae963c3b222d15b", size = 3132311, upload-time = "2026-04-13T19:45:36.688Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/83/6d810a8a9ebc9c307989b418840c20e46907c74d707beb67ab566773e6fc/xarray-2026.4.0-py3-none-any.whl", hash = "sha256:d43751d9fb4a90f9249c30431684f00c41bc874f1edccd862631a40cbc0edf08", size = 1414326, upload-time = "2026-04-13T19:45:34.659Z" }, +] + [[package]] name = "xxhash" version = "3.5.0"