-
Notifications
You must be signed in to change notification settings - Fork 263
Fix square mesh assumption in create_graph to support non-square domains #373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e7fadaa
67789e4
a15deaf
6943ca3
b460037
073eb21
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -245,6 +245,8 @@ def create_graph( | |
| nlev = int(np.log(max(xy.shape[:2])) / np.log(nx)) | ||
| nleaf = nx**nlev # leaves at the bottom = nleaf**2 | ||
|
|
||
| x_extent = xy[:, :, 0].max() - xy[:, :, 0].min() | ||
| y_extent = xy[:, :, 1].max() - xy[:, :, 1].min() | ||
| mesh_levels = nlev - 1 | ||
| if n_max_levels: | ||
| # Limit the levels in mesh graph | ||
|
|
@@ -254,14 +256,35 @@ def create_graph( | |
|
|
||
| # multi resolution tree levels | ||
| G = [] | ||
| # for tracking per-level dimensions, used for reshape and dm | ||
| mesh_dims: list[tuple[int, int]] = [] | ||
| for lev in range(1, mesh_levels + 1): | ||
| n = int(nleaf / (nx**lev)) | ||
| g = mk_2d_graph(xy, n, n) | ||
| if lev == 1: | ||
| n_larger = int(nleaf / (nx**lev)) | ||
| # scale n_x and n_y proportionally to physical domain aspect ratio, | ||
| # assign n_larger to whichever direction has greater physical extent | ||
| if x_extent >= y_extent: | ||
| n_x = n_larger | ||
| n_y = max(1, round(n_larger * y_extent / x_extent)) | ||
| else: | ||
| n_y = n_larger | ||
| n_x = max(1, round(n_larger * x_extent / y_extent)) | ||
| else: | ||
| # derive from actual slice of previous level | ||
| n_x_prev, n_y_prev = mesh_dims[lev - 2] | ||
| n_x = len(range(1, n_x_prev, nx)) | ||
| n_y = len(range(1, n_y_prev, nx)) | ||
|
|
||
| if (n_y < 3 or n_x < 3) and lev > 1: # always allow level 1 | ||
| mesh_levels = lev - 1 | ||
| break | ||
| g = mk_2d_graph(xy, n_x, n_y) | ||
| if create_plot: | ||
| plot_graph(from_networkx(g), title=f"Mesh graph, level {lev}") | ||
| plt.show() | ||
|
|
||
| G.append(g) | ||
| mesh_dims.append((n_x, n_y)) | ||
|
|
||
| if hierarchical: | ||
| # Relabel nodes of each level with level index first | ||
|
|
@@ -375,11 +398,12 @@ def create_graph( | |
| G_tot = G[0] | ||
| for lev in range(1, len(G)): | ||
| nodes = list(G[lev - 1].nodes) | ||
| n = int(np.sqrt(len(nodes))) | ||
| n_x_prev, n_y_prev = mesh_dims[lev - 1] | ||
| n_x_curr, n_y_curr = mesh_dims[lev] | ||
| ij = ( | ||
| np.array(nodes) | ||
| .reshape((n, n, 2))[1::nx, 1::nx, :] | ||
| .reshape(int(n / nx) ** 2, 2) | ||
| .reshape((n_x_prev, n_y_prev, 2))[1::nx, 1::nx, :] | ||
| .reshape(n_x_curr * n_y_curr, 2) | ||
| ) | ||
| ij = [tuple(x) for x in ij] | ||
| G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) | ||
|
|
@@ -429,9 +453,18 @@ def create_graph( | |
| vm = G_bottom_mesh.nodes | ||
| vm_xy = np.array([xy for _, xy in vm.data("pos")]) | ||
| # distance between mesh nodes | ||
| dm = np.sqrt( | ||
| np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) | ||
| ) | ||
| dx = x_extent / mesh_dims[0][0] # cell width at finest mesh level | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This changes g2m connectivity for every existing square graph, not just non-square ones. Previously the radius was based on one mesh-node spacing; here it becomes the cell diagonal, which is sqrt(2) larger on square meshes. On an 81x81 square grid, g2m edges jump from 9009 on main to 17757 here. That is a topology regression unless it was explicitly intended and rebase lined everywhere. |
||
| dy = y_extent / mesh_dims[0][1] # cell height at finest mesh level | ||
|
|
||
| # compare using mesh_dims because dx dy can cause rounding issue | ||
| if mesh_dims[0][0] == mesh_dims[0][1]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This condition should be based on spacing equivalence (for example dx vs dy with tolerance), not only n_x == n_y. Rounded equal dims can happen on non-square domains and cause coverage regressions. |
||
| dm = np.sqrt( | ||
| np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) | ||
| ) | ||
| else: | ||
| dm = np.sqrt( | ||
| dx**2 + dy**2 | ||
| ) # diagonal of mesh cell (coverage radius basis) | ||
|
|
||
| # grid nodes | ||
| Nx, Ny = xy.shape[:2] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,12 @@ | |
| from pathlib import Path | ||
|
|
||
| # Third-party | ||
| import numpy as np | ||
| import pytest | ||
| import torch | ||
|
|
||
| # First-party | ||
| from neural_lam.create_graph import create_graph_from_datastore | ||
| from neural_lam.create_graph import create_graph, create_graph_from_datastore | ||
| from neural_lam.datastore import DATASTORES | ||
| from neural_lam.datastore.base import BaseRegularGridDatastore | ||
| from tests.conftest import init_datastore_example | ||
|
|
@@ -117,3 +118,63 @@ def test_graph_creation(datastore_name, graph_name): | |
| assert r.shape[0] == 2 # adjacency matrix uses two rows | ||
| elif file_id.endswith("_features"): | ||
| assert r.shape[1] == d_features | ||
|
|
||
|
|
||
| def test_graph_creation_non_square_aspect_ratio(): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test only covers the n_max_levels=1 case, so it misses the actual breakage introduced by the new sizing logic. Please add at least one rectangular multiscale case (n_max_levels=2) and one square-grid regression check for unchanged g2m connectivity, otherwise both failures above still pass CI. |
||
| """ | ||
| Mesh at level 1 should reflect the domain's aspect ratio, not be square. | ||
| """ | ||
| Nx, Ny = 100, 600 | ||
| x = np.linspace(0, 1, Nx) | ||
| y = np.linspace(0, 6, Ny) | ||
| xx, yy = np.meshgrid(x, y, indexing="ij") | ||
| xy = np.stack([xx, yy], axis=-1) # (100, 600, 2) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1) | ||
| mesh_features = torch.load(f"{tmpdir}/mesh_features.pt") | ||
| n_mesh_nodes = mesh_features[0].shape[0] | ||
|
|
||
| # With a square mesh, n_mesh_nodes would be n x n. | ||
| # With correct aspect ratio, n_y > n_x, so nodes < n_larger^2. | ||
| n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3) | ||
| assert n_mesh_nodes < n_larger**2 | ||
|
|
||
|
|
||
| def test_graph_creation_multiscale_non_square(): | ||
| """ | ||
| Multi-level rectangular mesh must not crash at reshape. | ||
| Mesh should also reflect the domain aspect ratio. | ||
| """ | ||
| Nx, Ny = 100, 600 | ||
| x = np.linspace(0, 1, Nx) | ||
| y = np.linspace(0, 6, Ny) | ||
| xx, yy = np.meshgrid(x, y, indexing="ij") | ||
| xy = np.stack([xx, yy], axis=-1) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=2) | ||
| mesh_features = torch.load(f"{tmpdir}/mesh_features.pt") | ||
| n_mesh_nodes = mesh_features[0].shape[0] | ||
|
|
||
| n_larger = int(3 ** int(np.log(max(Nx, Ny)) / np.log(3)) / 3) | ||
| assert n_mesh_nodes < n_larger**2 | ||
|
|
||
|
|
||
| def test_graph_creation_square_g2m_edges_unchanged(): | ||
| """ | ||
| Square grid g2m edge count must be preserved. | ||
| Edge count must use the original square formula over rectangular domains. | ||
| """ | ||
| Nx, Ny = 81, 81 | ||
| x = np.linspace(0, 1, Nx) | ||
| y = np.linspace(0, 1, Ny) | ||
| xx, yy = np.meshgrid(x, y, indexing="ij") | ||
| xy = np.stack([xx, yy], axis=-1) | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| create_graph(graph_dir_path=tmpdir, xy=xy, n_max_levels=1) | ||
| g2m = torch.load(f"{tmpdir}/g2m_edge_index.pt") | ||
| assert ( | ||
| g2m.shape[1] == 9009 | ||
| ), f"Square grid g2m edges changed: expected 9009, got {g2m.shape[1]}" | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please guard degenerate coordinate extents before division. A zero extent on one axis should fail fast with a clear error or use a safe fallback.