Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ def create_graph(
nlev = int(np.log(max(xy.shape[:2])) / np.log(nx))
nleaf = nx**nlev # leaves at the bottom = nleaf**2

x_extent = xy[:, :, 0].max() - xy[:, :, 0].min()
y_extent = xy[:, :, 1].max() - xy[:, :, 1].min()
mesh_levels = nlev - 1
if n_max_levels:
# Limit the levels in mesh graph
Expand All @@ -254,14 +256,35 @@ def create_graph(

# multi resolution tree levels
G = []
# for tracking per-level dimensions, used for reshape and dm
mesh_dims: list[tuple[int, int]] = []
for lev in range(1, mesh_levels + 1):
n = int(nleaf / (nx**lev))
g = mk_2d_graph(xy, n, n)
if lev == 1:
n_larger = int(nleaf / (nx**lev))
# scale n_x and n_y proportionally to physical domain aspect ratio,
# assign n_larger to whichever direction has greater physical extent
if x_extent >= y_extent:
n_x = n_larger
n_y = max(1, round(n_larger * y_extent / x_extent))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please guard degenerate coordinate extents before division. A zero extent on one axis should fail fast with a clear error or use a safe fallback.

else:
n_y = n_larger
n_x = max(1, round(n_larger * x_extent / y_extent))
else:
# derive from actual slice of previous level
n_x_prev, n_y_prev = mesh_dims[lev - 2]
n_x = len(range(1, n_x_prev, nx))
n_y = len(range(1, n_y_prev, nx))

if (n_y < 3 or n_x < 3) and lev > 1: # always allow level 1
mesh_levels = lev - 1
break
g = mk_2d_graph(xy, n_x, n_y)
if create_plot:
plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}")
plt.show()

G.append(g)
mesh_dims.append((n_x, n_y))

if hierarchical:
# Relabel nodes of each level with level index first
Expand Down Expand Up @@ -375,11 +398,12 @@ def create_graph(
G_tot = G[0]
for lev in range(1, len(G)):
nodes = list(G[lev - 1].nodes)
n = int(np.sqrt(len(nodes)))
n_x_prev, n_y_prev = mesh_dims[lev - 1]
n_x_curr, n_y_curr = mesh_dims[lev]
ij = (
np.array(nodes)
.reshape((n, n, 2))[1::nx, 1::nx, :]
.reshape(int(n / nx) ** 2, 2)
.reshape((n_x_prev, n_y_prev, 2))[1::nx, 1::nx, :]
.reshape(n_x_curr * n_y_curr, 2)
)
ij = [tuple(x) for x in ij]
G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij)))
Expand Down Expand Up @@ -429,9 +453,18 @@ def create_graph(
vm = G_bottom_mesh.nodes
vm_xy = np.array([xy for _, xy in vm.data("pos")])
# distance between mesh nodes
dm = np.sqrt(
np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
)
dx = x_extent / mesh_dims[0][0] # cell width at finest mesh level
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This changes g2m connectivity for every existing square graph, not just non-square ones. Previously the radius was based on one mesh-node spacing; here it becomes the cell diagonal, which is sqrt(2) larger on square meshes. On an 81x81 square grid, g2m edges jump from 9009 on main to 17757 here. That is a topology regression unless it was explicitly intended and baselined everywhere.

dy = y_extent / mesh_dims[0][1] # cell height at finest mesh level

# compare using mesh_dims because dx dy can cause rounding issue
if mesh_dims[0][0] == mesh_dims[0][1]:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This condition should be based on spacing equivalence (for example dx vs dy with tolerance), not only n_x == n_y. Rounded equal dims can happen on non-square domains and cause coverage regressions.

dm = np.sqrt(
np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
)
else:
dm = np.sqrt(
dx**2 + dy**2
) # diagonal of mesh cell (coverage radius basis)

# grid nodes
Nx, Ny = xy.shape[:2]
Expand Down
63 changes: 62 additions & 1 deletion tests/test_graph_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from pathlib import Path

# Third-party
import numpy as np
import pytest
import torch

# First-party
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.create_graph import create_graph, create_graph_from_datastore
from neural_lam.datastore import DATASTORES
from neural_lam.datastore.base import BaseRegularGridDatastore
from tests.conftest import init_datastore_example
Expand Down Expand Up @@ -117,3 +118,63 @@ def test_graph_creation(datastore_name, graph_name):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features


def test_graph_creation_non_square_aspect_ratio():
    """
    Mesh at level 1 should reflect the domain's aspect ratio, not be square.

    Builds a rectangular (100 x 600) grid whose physical y-extent is six
    times its x-extent, runs graph creation with a single mesh level, and
    checks the level-1 mesh has fewer nodes than a naive square mesh would.
    """
    Nx, Ny = 100, 600
    x = np.linspace(0, 1, Nx)
    y = np.linspace(0, 6, Ny)
    xx, yy = np.meshgrid(x, y, indexing="ij")
    xy = np.stack([xx, yy], axis=-1)  # (100, 600, 2)

    with tempfile.TemporaryDirectory() as tmpdir:
        create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1)
        mesh_features = torch.load(f"{tmpdir}/mesh_features.pt")
        n_mesh_nodes = mesh_features[0].shape[0]

        # n_larger mirrors create_graph's level-1 sizing (nleaf / nx with
        # nx = 3 hard-coded in the graph-creation code).
        # With a square mesh, n_mesh_nodes would be n_larger**2.
        # With correct aspect ratio, n_y > n_x, so nodes < n_larger**2.
        n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3)
        assert n_mesh_nodes < n_larger**2


def test_graph_creation_multiscale_non_square():
    """
    Multi-level rectangular mesh must not crash at reshape.
    Mesh should also reflect the domain aspect ratio.
    """
    n_grid_x, n_grid_y = 100, 600
    coords_x = np.linspace(0, 1, n_grid_x)
    coords_y = np.linspace(0, 6, n_grid_y)
    grid_xx, grid_yy = np.meshgrid(coords_x, coords_y, indexing="ij")
    xy = np.stack([grid_xx, grid_yy], axis=-1)

    with tempfile.TemporaryDirectory() as tmpdir:
        create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=2)
        mesh_features = torch.load(f"{tmpdir}/mesh_features.pt")
        n_mesh_nodes = mesh_features[0].shape[0]

        # Level-1 width a square mesh would use (nx = 3 in create_graph);
        # a correctly-shaped rectangular mesh must have fewer nodes.
        n_levels_full = int(np.log(max(n_grid_x, n_grid_y)) / np.log(3))
        n_larger = 3 ** (n_levels_full - 1)
        assert n_mesh_nodes < n_larger**2


def test_graph_creation_square_g2m_edges_unchanged():
    """
    Square grid g2m edge count must be preserved.
    Edge count must use the original square formula over rectangular domains.
    """
    side = 81
    axis_coords = np.linspace(0, 1, side)
    mesh_xx, mesh_yy = np.meshgrid(axis_coords, axis_coords, indexing="ij")
    xy = np.stack([mesh_xx, mesh_yy], axis=-1)

    with tempfile.TemporaryDirectory() as tmpdir:
        create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1)
        g2m = torch.load(f"{tmpdir}/g2m_edge_index.pt")
        # Regression pin: 9009 edges on an 81x81 square grid (value on main)
        assert (
            g2m.shape[1] == 9009
        ), f"Square grid g2m edges changed: expected 9009, got {g2m.shape[1]}"