33 changes: 25 additions & 8 deletions neural_lam/create_graph.py
@@ -245,6 +245,8 @@ def create_graph(
nlev = int(np.log(max(xy.shape[:2])) / np.log(nx))
nleaf = nx**nlev # leaves at the bottom = nleaf**2

x_extent = xy[:, :, 0].max() - xy[:, :, 0].min()
y_extent = xy[:, :, 1].max() - xy[:, :, 1].min()
mesh_levels = nlev - 1
if n_max_levels:
# Limit the levels in mesh graph
@@ -254,14 +256,28 @@

# multi resolution tree levels
G = []
mesh_dims = [] # for tracking per-level dimensions, used for reshape and dm
for lev in range(1, mesh_levels + 1):
n = int(nleaf / (nx**lev))
g = mk_2d_graph(xy, n, n)
n_larger = int(nleaf / (nx**lev))
Contributor:
This rounding scheme is not compatible with the existing [1::3, 1::3] parent-selection logic below. For a rectangular grid like xy.shape == (100, 600), this produces level dims (14, 81) and then (4, 27), but sampling the previous level yields (5, 27) parents, so create_graph(..., n_max_levels=2) crashes at line 399 with ValueError: cannot reshape array of size 270 into shape (108,2). The coarse dims need to be derived from the sampled previous level, or the relabel/compose logic needs to change.
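For concreteness, a small standalone reproduction of the mismatch (assumes nx == 3, the rounding introduced here, and the ~1:6 physical aspect ratio used in the new test; a sketch of the failure, not a fix):

```python
import numpy as np

nx = 3
nleaf = nx ** int(np.log(600) / np.log(nx))  # grid shape (100, 600) -> nleaf = 243
x_extent, y_extent = 1.0, 6.0                # physical extents as in the new test

dims = []
for lev in (1, 2):
    n_larger = int(nleaf / nx**lev)          # 81, then 27
    n_y = n_larger                           # y direction has the larger extent
    n_x = max(1, round(n_larger * x_extent / y_extent))
    dims.append((n_x, n_y))                  # (14, 81), then (4, 27)

# parent selection from level 1, as in the relabel step further down
parents = np.zeros((*dims[0], 2))[1::nx, 1::nx, :]
print(parents.shape[:2], dims[1])            # (5, 27) vs (4, 27)
# 5 * 27 * 2 = 270 values cannot be reshaped into (4 * 27, 2) = (108, 2)
```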

# scale n_x and n_y proportionally to physical domain aspect ratio,
# assigning n_larger to whichever direction has greater physical extent
if x_extent >= y_extent:
n_x = n_larger
n_y = max(1, round(n_larger * y_extent / x_extent))
else:
n_y = n_larger
n_x = max(1, round(n_larger * x_extent / y_extent))

if (n_y < 3 or n_x < 3) and lev > 1: # always allow level 1
mesh_levels = lev - 1
break
g = mk_2d_graph(xy, n_x, n_y)
if create_plot:
plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}")
plt.show()

G.append(g)
mesh_dims.append((n_x, n_y))

if hierarchical:
# Relabel nodes of each level with level index first
@@ -375,11 +391,12 @@ def create_graph(
G_tot = G[0]
for lev in range(1, len(G)):
nodes = list(G[lev - 1].nodes)
n = int(np.sqrt(len(nodes)))
n_x_prev, n_y_prev = mesh_dims[lev - 1]
n_x_curr, n_y_curr = mesh_dims[lev]
ij = (
np.array(nodes)
.reshape((n, n, 2))[1::nx, 1::nx, :]
.reshape(int(n / nx) ** 2, 2)
.reshape((n_x_prev, n_y_prev, 2))[1::nx, 1::nx, :]
.reshape(n_x_curr * n_y_curr, 2)
)
ij = [tuple(x) for x in ij]
G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij)))
@@ -429,9 +446,9 @@ def create_graph(
vm = G_bottom_mesh.nodes
vm_xy = np.array([xy for _, xy in vm.data("pos")])
# distance between mesh nodes
dm = np.sqrt(
np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
)
dx = x_extent / mesh_dims[0][0] # cell width at finest mesh level
Contributor:
This changes g2m connectivity for every existing square graph, not just non-square ones. Previously the radius was based on one mesh-node spacing; here it becomes the cell diagonal, which is sqrt(2) larger on square meshes. On an 81x81 square grid, g2m edges jump from 9009 on main to 17757 here. That is a topology regression unless it was explicitly intended and re-baselined everywhere.
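A quick back-of-the-envelope for the square case (hypothetical unit-square domain with the finest mesh taken as 81 x 81; node spacing approximated as extent / n, so only the ratio matters):

```python
import numpy as np

n = 81                             # finest-level mesh dims on a square domain
extent = 1.0                       # hypothetical unit square
dm_old = extent / n                # roughly one mesh-node spacing (old radius basis)
dx = dy = extent / n               # cell width/height used by this PR
dm_new = np.sqrt(dx**2 + dy**2)    # cell diagonal (new radius basis)
print(dm_new / dm_old)             # ~1.414, i.e. a sqrt(2) larger g2m search radius
```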

dy = y_extent / mesh_dims[0][1] # cell height at finest mesh level
dm = np.sqrt(dx**2 + dy**2) # diagonal of mesh cell (coverage radius basis)

# grid nodes
Nx, Ny = xy.shape[:2]
24 changes: 23 additions & 1 deletion tests/test_graph_creation.py
@@ -3,11 +3,12 @@
from pathlib import Path

# Third-party
import numpy as np
import pytest
import torch

# First-party
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.create_graph import create_graph, create_graph_from_datastore
from neural_lam.datastore import DATASTORES
from neural_lam.datastore.base import BaseRegularGridDatastore
from tests.conftest import init_datastore_example
@@ -117,3 +118,24 @@ def test_graph_creation(datastore_name, graph_name):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features


def test_graph_creation_non_square_aspect_ratio():
Contributor:
This test only covers the n_max_levels=1 case, so it misses the actual breakage introduced by the new sizing logic. Please add at least one rectangular multiscale case (n_max_levels=2) and one square-grid regression check for unchanged g2m connectivity, otherwise both failures above still pass CI. A rough sketch of both cases follows after the test below.

"""
Mesh at level 1 should reflect the domain's aspect ratio, not be square.
"""
Nx, Ny = 100, 600
x = np.linspace(0, 1, Nx)
y = np.linspace(0, 6, Ny)
xx, yy = np.meshgrid(x, y, indexing="ij")
xy = np.stack([xx, yy], axis=-1) # (100, 600, 2)

with tempfile.TemporaryDirectory() as tmpdir:
create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1)
mesh_features = torch.load(f"{tmpdir}/mesh_features.pt")
n_mesh_nodes = mesh_features[0].shape[0]

# With a square mesh, n_mesh_nodes would be n x n.
# With correct aspect ratio, n_y > n_x, so nodes < n_larger^2.
n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3)
assert n_mesh_nodes < n_larger**2
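
As requested in the review comment above, a rough sketch of the two missing cases (a rectangular n_max_levels=2 build and a square-grid g2m regression check). It assumes the same imports as this test file; the g2m_edge_index.pt file name and the 9009 edge count are taken from the review discussion and would need to be verified against main before merging:

```python
def test_graph_creation_rectangular_multiscale():
    """Rectangular grid with n_max_levels=2 should build without the reshape crash."""
    Nx, Ny = 100, 600
    x = np.linspace(0, 1, Nx)
    y = np.linspace(0, 6, Ny)
    xx, yy = np.meshgrid(x, y, indexing="ij")
    xy = np.stack([xx, yy], axis=-1)

    with tempfile.TemporaryDirectory() as tmpdir:
        create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=2)


def test_graph_creation_square_grid_g2m_regression():
    """Square grid: g2m connectivity should match the edge count on main."""
    N = 81
    coords = np.linspace(0, 1, N)
    xx, yy = np.meshgrid(coords, coords, indexing="ij")
    xy = np.stack([xx, yy], axis=-1)

    with tempfile.TemporaryDirectory() as tmpdir:
        create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1)
        # file name and expected count assumed from the review discussion above
        g2m_edge_index = torch.load(f"{tmpdir}/g2m_edge_index.pt")
        assert g2m_edge_index.shape[1] == 9009
```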