From e7fadaaaa5f83908c70925e20ca04c2938b4ae58 Mon Sep 17 00:00:00 2001 From: osten-antonio Date: Wed, 11 Mar 2026 01:11:50 +0700 Subject: [PATCH 1/4] fix: adjust mesh dimensions to account for non-square aspect ratios in create_graph --- neural_lam/create_graph.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index 2af070525..f0d1f0cfc 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -245,6 +245,8 @@ def create_graph( nlev = int(np.log(max(xy.shape[:2])) / np.log(nx)) nleaf = nx**nlev # leaves at the bottom = nleaf**2 + x_extent = xy[:, :, 0].max() - xy[:, :, 0].min() + y_extent = xy[:, :, 1].max() - xy[:, :, 1].min() mesh_levels = nlev - 1 if n_max_levels: # Limit the levels in mesh graph @@ -254,14 +256,28 @@ def create_graph( # multi resolution tree levels G = [] + mesh_dims = [] # for tracking per-level dimensions, used for reshape and dm for lev in range(1, mesh_levels + 1): - n = int(nleaf / (nx**lev)) - g = mk_2d_graph(xy, n, n) + n_larger = int(nleaf / (nx**lev)) + # scale n_x and n_y proportionally to physical domain aspect ratio, + # assigning n_larger to whichever direction has greater physical extent + if x_extent >= y_extent: + n_x = n_larger + n_y = max(1, round(n_larger * y_extent / x_extent)) + else: + n_y = n_larger + n_x = max(1, round(n_larger * x_extent / y_extent)) + + if (n_y < 3 or n_x < 3) and lev > 1: # always allow level 1 + mesh_levels = lev - 1 + break + g = mk_2d_graph(xy, n_x, n_y) if create_plot: plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") plt.show() G.append(g) + mesh_dims.append((n_x, n_y)) if hierarchical: # Relabel nodes of each level with level index first @@ -375,11 +391,12 @@ def create_graph( G_tot = G[0] for lev in range(1, len(G)): nodes = list(G[lev - 1].nodes) - n = int(np.sqrt(len(nodes))) + n_x_prev, n_y_prev = mesh_dims[lev - 1] + n_x_curr, n_y_curr = mesh_dims[lev] ij = ( np.array(nodes) - .reshape((n, n, 2))[1::nx, 1::nx, :] - .reshape(int(n / nx) ** 2, 2) + .reshape((n_x_prev, n_y_prev, 2))[1::nx, 1::nx, :] + .reshape(n_x_curr * n_y_curr, 2) ) ij = [tuple(x) for x in ij] G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) @@ -429,9 +446,9 @@ def create_graph( vm = G_bottom_mesh.nodes vm_xy = np.array([xy for _, xy in vm.data("pos")]) # distance between mesh nodes - dm = np.sqrt( - np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) - ) + dx = x_extent / mesh_dims[0][0] # cell width at finest mesh level + dy = y_extent / mesh_dims[0][1] # cell height at finest mesh level + dm = np.sqrt(dx**2 + dy**2) # diagonal of mesh cell (coverage radius basis) # grid nodes Nx, Ny = xy.shape[:2] From 67789e4bee9964ecfd879191705105f03eb176fc Mon Sep 17 00:00:00 2001 From: osten-antonio Date: Wed, 11 Mar 2026 01:15:03 +0700 Subject: [PATCH 2/4] test: add test for non-square mesh aspect ratio in graph creation --- tests/test_graph_creation.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 93a7a55f4..312d0f885 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -3,11 +3,12 @@ from pathlib import Path # Third-party +import numpy as np import pytest import torch # First-party -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph import create_graph, create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from tests.conftest import init_datastore_example @@ -117,3 +118,24 @@ def test_graph_creation(datastore_name, graph_name): assert r.shape[0] == 2 # adjacency matrix uses two rows elif file_id.endswith("_features"): assert r.shape[1] == d_features + + +def test_graph_creation_non_square_aspect_ratio(): + """ + Mesh at level 1 should reflect the domain's aspect ratio, not be square. + """ + Nx, Ny = 100, 600 + x = np.linspace(0, 1, Nx) + y = np.linspace(0, 6, Ny) + xx, yy = np.meshgrid(x, y, indexing="ij") + xy = np.stack([xx, yy], axis=-1) # (100, 600, 2) + + with tempfile.TemporaryDirectory() as tmpdir: + create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1) + mesh_features = torch.load(f"{tmpdir}/mesh_features.pt") + n_mesh_nodes = mesh_features[0].shape[0] + + # With a square mesh, n_mesh_nodes would be n x n. + # With correct aspect ratio, n_y > n_x, so nodes < n_larger^2. + n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3) + assert n_mesh_nodes < n_larger**2 From a15deaf61999e6c7e13c7dfbe78f5f69fcfcef47 Mon Sep 17 00:00:00 2001 From: osten-antonio Date: Fri, 13 Mar 2026 16:17:30 +0700 Subject: [PATCH 3/4] fix: derive coarse mesh dims from prev slice and preserve square grid dm --- neural_lam/create_graph.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index f0d1f0cfc..250d84a4a 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -256,17 +256,24 @@ def create_graph( # multi resolution tree levels G = [] - mesh_dims = [] # for tracking per-level dimensions, used for reshape and dm + # for tracking per-level dimensions, used for reshape and dm + mesh_dims: list[tuple[int, int]] = [] for lev in range(1, mesh_levels + 1): - n_larger = int(nleaf / (nx**lev)) - # scale n_x and n_y proportionally to physical domain aspect ratio, - # assigning n_larger to whichever direction has greater physical extent - if x_extent >= y_extent: - n_x = n_larger - n_y = max(1, round(n_larger * y_extent / x_extent)) + if lev == 1: + n_larger = int(nleaf / (nx**lev)) + # scale n_x and n_y proportionally to physical domain aspect ratio, + # assign n_larger to whichever direction has greater physical extent + if x_extent >= y_extent: + n_x = n_larger + n_y = max(1, round(n_larger * y_extent / x_extent)) + else: + n_y = n_larger + n_x = max(1, round(n_larger * x_extent / y_extent)) else: - n_y = n_larger - n_x = max(1, round(n_larger * x_extent / y_extent)) + # derive from actual slice of previous level + n_x_prev, n_y_prev = mesh_dims[lev - 2] + n_x = len(range(1, n_x_prev, nx)) + n_y = len(range(1, n_y_prev, nx)) if (n_y < 3 or n_x < 3) and lev > 1: # always allow level 1 mesh_levels = lev - 1 @@ -448,7 +455,16 @@ def create_graph( # distance between mesh nodes dx = x_extent / mesh_dims[0][0] # cell width at finest mesh level dy = y_extent / mesh_dims[0][1] # cell height at finest mesh level - dm = np.sqrt(dx**2 + dy**2) # diagonal of mesh cell (coverage radius basis) + + # compare using mesh_dims because dx dy can cause rounding issue + if mesh_dims[0][0] == mesh_dims[0][1]: + dm = np.sqrt( + np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) + ) + else: + dm = np.sqrt( + dx**2 + dy**2 + ) # diagonal of mesh cell (coverage radius basis) # grid nodes Nx, Ny = xy.shape[:2] From 6943ca39b8723c87e3c844fcae28dd0700e4b109 Mon Sep 17 00:00:00 2001 From: osten-antonio Date: Fri, 13 Mar 2026 16:29:05 +0700 Subject: [PATCH 4/4] test: add test for multiscale non-square mesh creation and square grid edge preservation --- tests/test_graph_creation.py | 39 ++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 312d0f885..0e41a5917 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -139,3 +139,42 @@ def test_graph_creation_non_square_aspect_ratio(): # With correct aspect ratio, n_y > n_x, so nodes < n_larger^2. n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3) assert n_mesh_nodes < n_larger**2 + + +def test_graph_creation_multiscale_non_square(): + """ + Multi-level rectangular mesh must not crash at reshape. + Mesh should also reflect the domain aspect ratio. + """ + Nx, Ny = 100, 600 + x = np.linspace(0, 1, Nx) + y = np.linspace(0, 6, Ny) + xx, yy = np.meshgrid(x, y, indexing="ij") + xy = np.stack([xx, yy], axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=2) + mesh_features = torch.load(f"{tmpdir}/mesh_features.pt") + n_mesh_nodes = mesh_features[0].shape[0] + + n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3) + assert n_mesh_nodes < n_larger**2 + + +def test_graph_creation_square_g2m_edges_unchanged(): + """ + Square grid g2m edge count must be preserved. + Edge count must use the original square formula over rectangular domains. + """ + Nx, Ny = 81, 81 + x = np.linspace(0, 1, Nx) + y = np.linspace(0, 1, Ny) + xx, yy = np.meshgrid(x, y, indexing="ij") + xy = np.stack([xx, yy], axis=-1) + + with tempfile.TemporaryDirectory() as tmpdir: + create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1) + g2m = torch.load(f"{tmpdir}/g2m_edge_index.pt") + assert ( + g2m.shape[1] == 9009 + ), f"Square grid g2m edges changed: expected 9009, got {g2m.shape[1]}"