Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
88ca8a4
feat(adastra): port ChargE3Net fine-tuning to AMD MI250X on CINES Ada…
speckhard May 19, 2026
097eefb
test(charge3net): structural rotational-equivariance + architecture g…
speckhard May 19, 2026
fcb3236
feat(charge3net): DDP support + wandb soft-fail for Adastra half-node…
speckhard May 20, 2026
5c92beb
feat(submit): parameterize Adastra submit script for pretrained vs fr…
speckhard May 20, 2026
95ff39c
feat(deepdft): LeMat-Rho -> DeepDFT data adapter (TDD)
speckhard May 20, 2026
8d510d2
feat(deepdft): vendored runner + half-node DDP submit script (PR 2/2)
speckhard May 20, 2026
6374ef8
fix(deepdft): paper-faithful single-GPU + drop val-loader eager preload
speckhard May 20, 2026
8657f1a
fix(submit): drop --mem so half-node ChargE3Net jobs stay in shared mode
speckhard May 20, 2026
e8e84c7
fix(data): bounded LRU on _TABLE_CACHE + drop num-workers to 2
speckhard May 20, 2026
21ddeeb
feat(salted): BasisSpec dataclass + TDD tests (PR alpha of stacked st…
speckhard May 21, 2026
1909333
feat(salted): projection + reconstruction layers (PR beta)
speckhard May 21, 2026
cbfeec6
feat(salted): SALTEDModel wrapper + metric integration (PR gamma)
speckhard May 21, 2026
02cdce7
feat(salted): CHGCAR I/O wrapper + VASP hook gate (PR delta)
speckhard May 21, 2026
22809b9
fix(salted): swap orthonormal-approx projection for LSQR
speckhard May 21, 2026
265b62a
feat(salted): D2 dataset projection script (project_chunk + project_d…
speckhard May 21, 2026
8616230
feat(salted): D2 SLURM submit for the LeMat-Rho dataset projection
speckhard May 21, 2026
274ce74
fix(submit): bump NCCL timeout + heartbeat tolerance for ChargE3Net DDP
speckhard May 21, 2026
0ec5177
feat(salted): rholearn data-format adapter (D3)
speckhard May 21, 2026
c99c01a
chore(deps): add metatensor + chemfiles for the rholearn adapter
speckhard May 21, 2026
96d8292
feat(graph2mat): BasisSpec -> PointBasis adapter (D5 PR zeta-alpha)
speckhard May 22, 2026
2a94222
fix(submit): trim D2 SLURM CPU ask to keep it in genoa-shared
speckhard May 22, 2026
a5a143f
feat(graph2mat): per-atom coefficient projection (D5 PR zeta-beta)
speckhard May 22, 2026
05d2837
feat(graph2mat): Graph2MatModel wrapper with stub mode (D5 PR zeta-ga…
speckhard May 22, 2026
714b5aa
feat(graph2mat): shared CHGCAR IO surface (D5 PR zeta-delta)
speckhard May 22, 2026
10003be
chore(graph2mat): mark arm parked, cite the projection blocker
speckhard May 25, 2026
9ae7ef2
feat(eval): per-structure density model eval script (D7-alpha)
speckhard May 25, 2026
ca43941
feat(eval): cross-arm density comparison table (D8)
speckhard May 25, 2026
f034891
fix(salted): blake2b stub seed so atoms past index 0 contribute
speckhard May 25, 2026
9882246
feat(eval): ChargE3Net grid prediction with probe batching (D7-beta1)
speckhard May 26, 2026
6581505
feat(eval): DeepDFT grid prediction wired (D7-beta2)
speckhard May 26, 2026
5616943
feat(salted): SchNet-style baseline coefficient predictor + train loo…
speckhard May 26, 2026
edab3dd
feat(salted): wire D6 baseline ckpt into SALTEDModel inference
speckhard May 26, 2026
cdac029
fix(deepdft): cut RotatingPoolData + num_workers for LeMat-Rho grid s…
speckhard May 26, 2026
c40e32c
feat(scf): scf_speedup_run.py driver — predict CHGCAR, submit paired …
speckhard Jun 1, 2026
c88c838
fix(scf): per-row error handling + JSONL manifest + resumable runs (P…
speckhard Jun 1, 2026
0937ffa
fix(scf): nested CHGCAR dir layout + non-degenerate test row (P4 polish)
speckhard Jun 1, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions charge3net_ft/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
opened tables per chunk file so each file is read from disk only once per worker.
"""

import collections
import json
import sys
from functools import partial
Expand Down Expand Up @@ -55,10 +56,20 @@
# ---------------------------------------------------------------------------
_SYMBOL_TO_Z = {s: z for z, s in enumerate(ase.data.chemical_symbols)}

# Process-local table cache: keyed by file index, populated on first access.
# Each DataLoader worker process has its own cache, so each chunk file is read
# from disk at most once per worker instead of once per __getitem__ call.
_TABLE_CACHE: dict = {}
# Process-local LRU table cache: keyed by file index, populated on first access.
# Each DataLoader worker has its own cache (workers fork the parent), so each
# chunk file is read from disk at most once per worker per cache cycle.
#
# Bounded LRU because the previous unbounded version OOM-killed jobs 4971293
# and 4971343 at MaxRSS=35 GB/rank. Per-chunk decompressed pyarrow tables
# weigh ~2 GB (the compressed_charge_density JSON strings inflate 6x from
# disk). With 8 workers x 4 DDP ranks = 32 workers, an unbounded cache grew
# to ~140 GB total in 6 h.
#
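# Cap of 5 chunks per worker keeps each worker's cache around 10 GB worst
# case (5 chunks x ~2 GB each). The cap is per worker, not per rank: a rank's
# worst-case cache footprint is that figure times num_workers, so it only
# stays under the per-rank memory budget when num_workers is kept small.
# OrderedDict gives O(1) LRU.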
_TABLE_CACHE_MAX_CHUNKS = 5
_TABLE_CACHE: "collections.OrderedDict[int, object]" = collections.OrderedDict()


def _parse_grid_json(json_str: str) -> np.ndarray:
Expand Down Expand Up @@ -131,7 +142,9 @@ def _build_parquet_index(parquet_dir: Path) -> tuple:
index.append((fi, ri))

n_valid = len(index)
print(f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files")
print(
f"LeMatRhoDataset: {n_valid}/{n_total} valid rows indexed from {len(file_paths)} files"
)
return file_paths, index


Expand Down Expand Up @@ -186,11 +199,21 @@ def _read_row(self, idx: int) -> dict:
"""
Read a single row from disk via its index entry.

Uses a process-local cache (_TABLE_CACHE) so each chunk file is
loaded from disk only once per worker, not on every __getitem__ call.
Uses a process-local LRU cache (_TABLE_CACHE) so a chunk file is read
from disk at most once per worker for as long as it stays cached. The
cache is capped at _TABLE_CACHE_MAX_CHUNKS entries; on a miss at
capacity the least-recently-used chunk is evicted before the new chunk
is read. Re-accessing a cached chunk promotes it to most-recent, so
chunks still being hit under shuffled access are the last to be evicted.
"""
fi, ri = self._index[idx]
if fi not in _TABLE_CACHE:
if fi in _TABLE_CACHE:
# Hit: bump to most-recent and return.
_TABLE_CACHE.move_to_end(fi)
else:
# Miss: evict LRU if at capacity, then read.
if len(_TABLE_CACHE) >= _TABLE_CACHE_MAX_CHUNKS:
_TABLE_CACHE.popitem(last=False)
_TABLE_CACHE[fi] = pq.read_table(self._file_paths[fi], columns=_COLUMNS)
table = _TABLE_CACHE[fi]
row = {}
Expand Down Expand Up @@ -230,6 +253,7 @@ def build_dataloaders(
num_workers: int = 4,
seed: int = 42,
pin_memory: bool = False,
distributed: bool = False,
) -> tuple:
"""
Build train, validation, and test DataLoaders.
Expand Down Expand Up @@ -298,10 +322,27 @@ def build_dataloaders(

collate_fn = partial(collate_list_of_dicts, pin_memory=pin_memory)

# DDP path: shard the training set across ranks via DistributedSampler.
# Val/test stay non-distributed (each rank evaluates the whole set; only
# rank 0 reports). This repeats the validation and test passes on every
# rank, but keeps eval simple and rank-agnostic. Both splits are tiny
# (5% + 5% of 65k), so the redundant compute is acceptable.
train_sampler = None
if distributed:
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
train_subset,
shuffle=True,
seed=seed,
drop_last=True,
)

train_loader = DataLoader(
train_subset,
batch_size=batch_size,
shuffle=True,
# shuffle and sampler are mutually exclusive in DataLoader.
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=num_workers,
collate_fn=collate_fn,
pin_memory=pin_memory,
Expand Down
Loading