From f51beba9a8d6413c9a59541562da16f50f7451f5 Mon Sep 17 00:00:00 2001 From: prajwal Date: Tue, 14 Apr 2026 00:21:33 +0530 Subject: [PATCH 1/8] feat: add create_graph_with_wmg.py CLI using weather-model-graphs --- neural_lam/create_graph.py | 9 ++ neural_lam/create_graph_with_wmg.py | 188 ++++++++++++++++++++++++++++ pyproject.toml | 6 +- tests/test_graph_creation.py | 103 +++++++++++++++ 4 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 neural_lam/create_graph_with_wmg.py diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index b24b96c23..03a9efcf7 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -1,5 +1,6 @@ # Standard library import os +import warnings from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from typing import Optional @@ -561,6 +562,14 @@ def create_graph_from_datastore( def cli(input_args=None): + warnings.warn( + "create_graph.py is deprecated and will be removed in a future " + "version. Use create_graph_with_wmg.py instead, which delegates " + "graph creation to weather-model-graphs (wmg). See " + "https://github.com/mllam/neural-lam/issues/384 for details.", + DeprecationWarning, + stacklevel=2, + ) parser = ArgumentParser( description="Graph generation for neural-lam", formatter_class=ArgumentDefaultsHelpFormatter, diff --git a/neural_lam/create_graph_with_wmg.py b/neural_lam/create_graph_with_wmg.py new file mode 100644 index 000000000..068b8d712 --- /dev/null +++ b/neural_lam/create_graph_with_wmg.py @@ -0,0 +1,188 @@ +# Standard library +import os +from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser + +# Third-party +import numpy as np +import weather_model_graphs as wmg + +# Local +from .config import load_config_and_datastore +from .datastore.base import BaseRegularGridDatastore + +ARCHETYPE_FUNCTIONS = { + "keisler": wmg.create.archetype.create_keisler_graph, + "graphcast": wmg.create.archetype.create_graphcast_graph, + "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, +} + + +def _estimate_mesh_node_distance(xy): + """Estimate a reasonable mesh node distance from grid coordinates. + + Uses the average grid spacing to produce a mesh that is roughly 3x + coarser than the grid, similar to the default behaviour of the old + ``create_graph.py`` script. + + Parameters + ---------- + xy : np.ndarray + Grid coordinates of shape ``(N, 2)``. + + Returns + ------- + float + Estimated mesh node distance in coordinate units. + """ + x_range = np.ptp(xy[:, 0]) + y_range = np.ptp(xy[:, 1]) + n_points = len(xy) + # avg grid spacing ≈ sqrt(area / n_points) + avg_spacing = np.sqrt(x_range * y_range / n_points) + # mesh is ~3x coarser than the grid + return float(avg_spacing * 3) + + +def create_graph_from_datastore( + datastore, + output_root_path, + archetype="keisler", + mesh_node_distance=None, + level_refinement_factor=3, + max_num_levels=None, +): + """Create graph using weather-model-graphs and save in neural-lam format. + + Parameters + ---------- + datastore : BaseRegularGridDatastore + Datastore providing grid coordinates. + output_root_path : str + Directory where the .pt graph files will be saved. + archetype : str + Graph archetype to create: ``"keisler"``, ``"graphcast"``, or + ``"hierarchical"``. + mesh_node_distance : float or None + Distance between created mesh nodes (in coordinate units). If None, + automatically estimated from the grid spacing. + level_refinement_factor : int + Refinement factor between mesh hierarchy levels. Only used for + ``"graphcast"`` and ``"hierarchical"`` archetypes. + max_num_levels : int or None + Maximum number of mesh hierarchy levels. Only used for ``"graphcast"`` + and ``"hierarchical"`` archetypes. + """ + if not isinstance(datastore, BaseRegularGridDatastore): + raise NotImplementedError( + "Only graph creation for BaseRegularGridDatastore is supported" + ) + + if archetype not in ARCHETYPE_FUNCTIONS: + raise ValueError( + f"Unknown archetype '{archetype}'. " + f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}" + ) + + xy = datastore.get_xy(category="state", stacked=False) + + # wmg expects coords as 2D array of shape (num_nodes, 2), but the + # datastore may return a 3D array of shape (Nx, Ny, 2) when + # stacked=False. Reshape to (N, 2) for wmg. + xy = np.array(xy) + if xy.ndim == 3: + xy = xy.reshape(-1, 2) + + if mesh_node_distance is None: + mesh_node_distance = _estimate_mesh_node_distance(xy) + + # Build keyword arguments for the archetype function + archetype_kwargs = dict( + coords=xy, + mesh_node_distance=mesh_node_distance, + return_components=True, + ) + + # Only multiscale/hierarchical archetypes accept these parameters + if archetype in ("graphcast", "hierarchical"): + archetype_kwargs["level_refinement_factor"] = level_refinement_factor + archetype_kwargs["max_num_levels"] = max_num_levels + + archetype_fn = ARCHETYPE_FUNCTIONS[archetype] + graph_components = archetype_fn(**archetype_kwargs) + + hierarchical = archetype == "hierarchical" + + wmg.save.to_neural_lam( + graph_components=graph_components, + output_directory=output_root_path, + hierarchical=hierarchical, + ) + + +def cli(input_args=None): + """Command-line interface for graph creation using weather-model-graphs.""" + parser = ArgumentParser( + description="Graph generation for neural-lam using " + "weather-model-graphs (wmg)", + formatter_class=ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--config_path", + type=str, + help="Path to neural-lam configuration file", + ) + parser.add_argument( + "--name", + type=str, + default="multiscale", + help="Name to save graph as (used as subdirectory name)", + ) + parser.add_argument( + "--archetype", + type=str, + default="keisler", + choices=["keisler", "graphcast", "hierarchical"], + help="Graph archetype to create", + ) + parser.add_argument( + "--mesh_node_distance", + type=float, + default=None, + help="Distance between mesh nodes (in coordinate units). " + "If not set, estimated automatically from grid spacing.", + ) + parser.add_argument( + "--level_refinement_factor", + type=int, + default=3, + help="Refinement factor between mesh hierarchy levels " + "(only used for graphcast and hierarchical)", + ) + parser.add_argument( + "--max_num_levels", + type=int, + default=None, + help="Maximum number of mesh levels " + "(only used for graphcast and hierarchical)", + ) + args = parser.parse_args(input_args) + + assert ( + args.config_path is not None + ), "Specify your config with --config_path" + + # Load neural-lam configuration and datastore to use + _, datastore = load_config_and_datastore(config_path=args.config_path) + + create_graph_from_datastore( + datastore=datastore, + output_root_path=os.path.join(datastore.root_path, "graph", args.name), + archetype=args.archetype, + mesh_node_distance=args.mesh_node_distance, + level_refinement_factor=args.level_refinement_factor, + max_num_levels=args.max_num_levels, + ) + + +if __name__ == "__main__": + cli() diff --git a/pyproject.toml b/pyproject.toml index bccec0828..9f9a752a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "matplotlib>=3.7.0", "plotly>=5.15.0", "torch>=2.3.0", - "torch-geometric==2.3.1", + "torch-geometric>=2.5.3", "parse>=1.20.2", "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", @@ -38,9 +38,13 @@ dependencies = [ "boto3>=1.35.32", "nvidia-ml-py>=13.580.82", "pillow>=9.0.0", + "weather-model-graphs[pytorch]>=0.3.0", ] requires-python = ">=3.10" +[project.scripts] +create_graph_with_wmg = "neural_lam.create_graph_with_wmg:cli" + [dependency-groups] dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 93a7a55f4..76a4f7bdc 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -1,5 +1,6 @@ # Standard library import tempfile +import warnings from pathlib import Path # Third-party @@ -8,6 +9,9 @@ # First-party from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import ( + create_graph_from_datastore as wmg_create_graph_from_datastore, +) from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from tests.conftest import init_datastore_example @@ -117,3 +121,102 @@ def test_graph_creation(datastore_name, graph_name): assert r.shape[0] == 2 # adjacency matrix uses two rows elif file_id.endswith("_features"): assert r.shape[1] == d_features + + +@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"]) +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_wmg_graph_creation(datastore_name, archetype): + """Check that graph creation via weather-model-graphs produces the + expected .pt files with correct shapes and types.""" + datastore = init_datastore_example(datastore_name) + + if not isinstance(datastore, BaseRegularGridDatastore): + pytest.skip( + f"Skipping test for {datastore_name} as it is not a regular " + "grid datastore." + ) + + hierarchical = archetype == "hierarchical" + + required_graph_files = [ + "m2m_edge_index.pt", + "g2m_edge_index.pt", + "m2g_edge_index.pt", + "m2m_features.pt", + "g2m_features.pt", + "m2g_features.pt", + "mesh_features.pt", + ] + if hierarchical: + required_graph_files.extend( + [ + "mesh_up_edge_index.pt", + "mesh_down_edge_index.pt", + "mesh_up_features.pt", + "mesh_down_features.pt", + ] + ) + + d_features = 3 + d_mesh_static = 2 + + with tempfile.TemporaryDirectory() as tmpdir: + graph_dir_path = Path(tmpdir) / "graph" / archetype + + wmg_create_graph_from_datastore( + datastore=datastore, + output_root_path=str(graph_dir_path), + archetype=archetype, + ) + + assert graph_dir_path.exists() + + # check that all the required files are present + for file_name in required_graph_files: + assert (graph_dir_path / file_name).exists() + + # try to load each and ensure they have the right shape + for file_name in required_graph_files: + file_id = Path(file_name).stem + result = torch.load(graph_dir_path / file_name, weights_only=True) + + if file_id.startswith("g2m") or file_id.startswith("m2g"): + assert isinstance(result, torch.Tensor) + + if file_id.endswith("_index"): + assert result.shape[0] == 2 + elif file_id.endswith("_features"): + assert result.shape[1] == d_features + + elif file_id.startswith("m2m") or file_id.startswith("mesh"): + assert isinstance(result, list) + + for r in result: + assert isinstance(r, torch.Tensor) + + if file_id == "mesh_features": + assert r.shape[1] == d_mesh_static + elif file_id.endswith("_index"): + assert r.shape[0] == 2 + elif file_id.endswith("_features"): + assert r.shape[1] == d_features + + +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_old_create_graph_deprecation_warning(datastore_name): + """Check that the old create_graph CLI emits a deprecation warning.""" + # First-party + from neural_lam.create_graph import cli as old_cli + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + try: + old_cli(["--config_path", "nonexistent.yaml"]) + except Exception: + pass # We only care about the warning, not the error + + deprecation_warnings = [ + x for x in w if issubclass(x.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 1 + assert "create_graph_with_wmg" in str(deprecation_warnings[0].message) From dc77271d8ec0aeb3dfad66815471902b26a02dbd Mon Sep 17 00:00:00 2001 From: prajwal Date: Tue, 14 Apr 2026 01:09:07 +0530 Subject: [PATCH 2/8] chore: point weather-model-graphs dep to PR #123 branch Update pyproject.toml to install weather-model-graphs from the issue-384/to-neural-lam branch of the fork, so that CI and reviewers can test the neural-lam side against the unreleased to_neural_lam() changes before wmg PR #123 is merged. Will revert to a versioned PyPI dependency once PR #123 is released. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9f9a752a8..1cbf201fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "boto3>=1.35.32", "nvidia-ml-py>=13.580.82", "pillow>=9.0.0", - "weather-model-graphs[pytorch]>=0.3.0", + "weather-model-graphs[pytorch] @ git+https://github.com/prajwal-tech07/weather-model-graphs@issue-384/to-neural-lam", ] requires-python = ">=3.10" From 2500ade833d5a1588373c4c61daeed863b289964 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 14 Apr 2026 07:36:32 +0200 Subject: [PATCH 3/8] fix pyproject.toml ref to wmg branch --- pyproject.toml | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1cbf201fe..0c6ee78fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ dependencies = [ "boto3>=1.35.32", "nvidia-ml-py>=13.580.82", "pillow>=9.0.0", - "weather-model-graphs[pytorch] @ git+https://github.com/prajwal-tech07/weather-model-graphs@issue-384/to-neural-lam", + "weather-model-graphs", ] requires-python = ">=3.10" @@ -134,3 +134,6 @@ exclude = [ ".venv/", "venv/", ] + +[tool.uv.sources] +weather-model-graphs = { git = "https://github.com/prajwal-tech07/weather-model-graphs", rev = "issue-384/to-neural-lam" } From d6da819cb6f27c9faa777da21fbb51791c6989d4 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 14 Apr 2026 08:42:09 +0200 Subject: [PATCH 4/8] make torch version detection in ci more robust --- .github/workflows/install-and-test.yml | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/install-and-test.yml b/.github/workflows/install-and-test.yml index 8e6be806f..8cb9b3c08 100644 --- a/.github/workflows/install-and-test.yml +++ b/.github/workflows/install-and-test.yml @@ -62,7 +62,20 @@ jobs: - name: Determine PyTorch version for neural-lam (using pip dry-run) if: matrix.package_manager == 'pip' run: | - TORCH_VERSION=$(python -m pip install --dry-run "." | grep "Would install" | grep -o 'torch-[0-9.]*' | awk -F'-' '{print $2}' | tail -n 1) + # 1. Run pip install in simulation mode for the current directory (".") + # --disable-pip-version-check stops the "[notice] A new release of pip is available" message. + # 2>/dev/null silences any other unexpected errors or warnings by throwing them away. + # 2. grep "Would install": Filter the output to only look at the summary of what pip intends to do. + # 3. grep -o 'torch-[0-9.]*': Extract just the package name and version for torch. + # 4. awk -F'-' '{print $2}': Strip the "torch-" prefix to get just the version number. + # 5. grep -v '^$': Remove any empty lines that might have been produced. + # 6. tail -n 1: Grab only the final result. + TORCH_VERSION=$(python -m pip install --disable-pip-version-check --dry-run "." 2>/dev/null \ + | grep "Would install" \ + | grep -o 'torch-[0-9.]*' \ + | awk -F'-' '{print $2}' \ + | grep -v '^$' \ + | tail -n 1) echo "Torch version detected: $TORCH_VERSION" echo "TORCH_VERSION=$TORCH_VERSION" >> $GITHUB_ENV From 2c679e70f93827236b59b03be1da70ceb8e510f0 Mon Sep 17 00:00:00 2001 From: prajwal Date: Tue, 14 Apr 2026 15:20:46 +0530 Subject: [PATCH 5/8] Address review: rename to _estimate_grid_node_spacing, expose grid_mesh_ratio, use stacked=True, add return_components comment, add README entry --- README.md | 23 ++++++++++++++ neural_lam/create_graph_with_wmg.py | 47 ++++++++++++++++------------- 2 files changed, 49 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 17e269841..c64b03cbd 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,9 @@ python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats **Note:** The `create_graph` command below is deprecated and will be removed +> in a future release. Please use `create_graph_with_wmg` (see below) instead. + Run `python -m neural_lam.create_graph` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_graph --help` for a list of options). The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as: @@ -399,6 +402,26 @@ The graphs used for the different models in the [paper](#graph-based-neural-weat The graph-related files are stored in a directory called `graphs`. +### Graph creation with weather-model-graphs + +The recommended way to create graphs is with the `create_graph_with_wmg` +command, which delegates graph construction to +[weather-model-graphs](https://github.com/mllam/weather-model-graphs): + +```bash +python -m neural_lam.create_graph_with_wmg --config_path --archetype +``` + +Available archetypes: + +* **keisler** (default): `python -m neural_lam.create_graph_with_wmg --config_path --archetype keisler` +* **graphcast**: `python -m neural_lam.create_graph_with_wmg --config_path --archetype graphcast` +* **hierarchical**: `python -m neural_lam.create_graph_with_wmg --config_path --archetype hierarchical` + +Run `python -m neural_lam.create_graph_with_wmg --help` for the full list of +options (e.g. `--mesh_node_distance`, `--grid_mesh_ratio`, +`--level_refinement_factor`, `--max_num_levels`). + ## Logging your experiments ### Weights & Biases Integration diff --git a/neural_lam/create_graph_with_wmg.py b/neural_lam/create_graph_with_wmg.py index 068b8d712..987cf733f 100644 --- a/neural_lam/create_graph_with_wmg.py +++ b/neural_lam/create_graph_with_wmg.py @@ -17,12 +17,8 @@ } -def _estimate_mesh_node_distance(xy): - """Estimate a reasonable mesh node distance from grid coordinates. - - Uses the average grid spacing to produce a mesh that is roughly 3x - coarser than the grid, similar to the default behaviour of the old - ``create_graph.py`` script. +def _estimate_grid_node_spacing(xy): + """Estimate the average grid node spacing from grid coordinates. Parameters ---------- @@ -32,15 +28,13 @@ def _estimate_mesh_node_distance(xy): Returns ------- float - Estimated mesh node distance in coordinate units. + Estimated average grid node spacing in coordinate units. """ x_range = np.ptp(xy[:, 0]) y_range = np.ptp(xy[:, 1]) n_points = len(xy) # avg grid spacing ≈ sqrt(area / n_points) - avg_spacing = np.sqrt(x_range * y_range / n_points) - # mesh is ~3x coarser than the grid - return float(avg_spacing * 3) + return float(np.sqrt(x_range * y_range / n_points)) def create_graph_from_datastore( @@ -48,6 +42,7 @@ def create_graph_from_datastore( output_root_path, archetype="keisler", mesh_node_distance=None, + grid_mesh_ratio=3.0, level_refinement_factor=3, max_num_levels=None, ): @@ -64,7 +59,10 @@ def create_graph_from_datastore( ``"hierarchical"``. mesh_node_distance : float or None Distance between created mesh nodes (in coordinate units). If None, - automatically estimated from the grid spacing. + automatically estimated as ``grid_mesh_ratio * grid_spacing``. + grid_mesh_ratio : float + Ratio of mesh node distance to grid node spacing. Only used when + ``mesh_node_distance`` is None. Default is 3.0. level_refinement_factor : int Refinement factor between mesh hierarchy levels. Only used for ``"graphcast"`` and ``"hierarchical"`` archetypes. @@ -83,19 +81,17 @@ def create_graph_from_datastore( f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}" ) - xy = datastore.get_xy(category="state", stacked=False) - - # wmg expects coords as 2D array of shape (num_nodes, 2), but the - # datastore may return a 3D array of shape (Nx, Ny, 2) when - # stacked=False. Reshape to (N, 2) for wmg. + xy = datastore.get_xy(category="state", stacked=True) xy = np.array(xy) - if xy.ndim == 3: - xy = xy.reshape(-1, 2) if mesh_node_distance is None: - mesh_node_distance = _estimate_mesh_node_distance(xy) + grid_spacing = _estimate_grid_node_spacing(xy) + mesh_node_distance = grid_spacing * grid_mesh_ratio - # Build keyword arguments for the archetype function + # Build keyword arguments for the archetype function. + # return_components=True is required because wmg.save.to_neural_lam() + # expects the graph as separate g2m, m2g and m2m sub-graph components + # rather than a single merged graph. archetype_kwargs = dict( coords=xy, mesh_node_distance=mesh_node_distance, @@ -149,7 +145,15 @@ def cli(input_args=None): type=float, default=None, help="Distance between mesh nodes (in coordinate units). " - "If not set, estimated automatically from grid spacing.", + "If not set, estimated automatically from grid spacing " + "and --grid_mesh_ratio.", + ) + parser.add_argument( + "--grid_mesh_ratio", + type=float, + default=3.0, + help="Ratio of mesh node distance to grid node spacing. " + "Only used when --mesh_node_distance is not set.", ) parser.add_argument( "--level_refinement_factor", @@ -179,6 +183,7 @@ def cli(input_args=None): output_root_path=os.path.join(datastore.root_path, "graph", args.name), archetype=args.archetype, mesh_node_distance=args.mesh_node_distance, + grid_mesh_ratio=args.grid_mesh_ratio, level_refinement_factor=args.level_refinement_factor, max_num_levels=args.max_num_levels, ) From 8a690e443f9e1e9539ed16b14628daf5e2bcc817 Mon Sep 17 00:00:00 2001 From: prajwal Date: Wed, 15 Apr 2026 17:08:52 +0530 Subject: [PATCH 6/8] Rename grid_mesh_ratio to grid_mesh_spacing_ratio, revert torch-geometric to ==2.3.1 --- README.md | 2 +- neural_lam/create_graph_with_wmg.py | 14 +++++++------- pyproject.toml | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index c64b03cbd..c4fa0b961 100644 --- a/README.md +++ b/README.md @@ -419,7 +419,7 @@ Available archetypes: * **hierarchical**: `python -m neural_lam.create_graph_with_wmg --config_path --archetype hierarchical` Run `python -m neural_lam.create_graph_with_wmg --help` for the full list of -options (e.g. `--mesh_node_distance`, `--grid_mesh_ratio`, +options (e.g. `--mesh_node_distance`, `--grid_mesh_spacing_ratio`, `--level_refinement_factor`, `--max_num_levels`). ## Logging your experiments diff --git a/neural_lam/create_graph_with_wmg.py b/neural_lam/create_graph_with_wmg.py index 987cf733f..a5a2b12ca 100644 --- a/neural_lam/create_graph_with_wmg.py +++ b/neural_lam/create_graph_with_wmg.py @@ -42,7 +42,7 @@ def create_graph_from_datastore( output_root_path, archetype="keisler", mesh_node_distance=None, - grid_mesh_ratio=3.0, + grid_mesh_spacing_ratio=3.0, level_refinement_factor=3, max_num_levels=None, ): @@ -59,8 +59,8 @@ def create_graph_from_datastore( ``"hierarchical"``. mesh_node_distance : float or None Distance between created mesh nodes (in coordinate units). If None, - automatically estimated as ``grid_mesh_ratio * grid_spacing``. - grid_mesh_ratio : float + automatically estimated as ``grid_mesh_spacing_ratio * grid_spacing``. + grid_mesh_spacing_ratio : float Ratio of mesh node distance to grid node spacing. Only used when ``mesh_node_distance`` is None. Default is 3.0. level_refinement_factor : int @@ -86,7 +86,7 @@ def create_graph_from_datastore( if mesh_node_distance is None: grid_spacing = _estimate_grid_node_spacing(xy) - mesh_node_distance = grid_spacing * grid_mesh_ratio + mesh_node_distance = grid_spacing * grid_mesh_spacing_ratio # Build keyword arguments for the archetype function. # return_components=True is required because wmg.save.to_neural_lam() @@ -146,10 +146,10 @@ def cli(input_args=None): default=None, help="Distance between mesh nodes (in coordinate units). " "If not set, estimated automatically from grid spacing " - "and --grid_mesh_ratio.", + "and --grid_mesh_spacing_ratio.", ) parser.add_argument( - "--grid_mesh_ratio", + "--grid_mesh_spacing_ratio", type=float, default=3.0, help="Ratio of mesh node distance to grid node spacing. " @@ -183,7 +183,7 @@ def cli(input_args=None): output_root_path=os.path.join(datastore.root_path, "graph", args.name), archetype=args.archetype, mesh_node_distance=args.mesh_node_distance, - grid_mesh_ratio=args.grid_mesh_ratio, + grid_mesh_spacing_ratio=args.grid_mesh_spacing_ratio, level_refinement_factor=args.level_refinement_factor, max_num_levels=args.max_num_levels, ) diff --git a/pyproject.toml b/pyproject.toml index 0c6ee78fb..a0ba18685 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "matplotlib>=3.7.0", "plotly>=5.15.0", "torch>=2.3.0", - "torch-geometric>=2.5.3", + "torch-geometric==2.3.1", "parse>=1.20.2", "dataclass-wizard<0.31.0", "mllam-data-prep>=0.5.0", From 76c7ed74dbe86515f12a9b38527065c36ef51bab Mon Sep 17 00:00:00 2001 From: prajwal Date: Wed, 15 Apr 2026 18:40:37 +0530 Subject: [PATCH 7/8] Switch tests to use wmg-based create_graph_from_datastore Migrated all test files that used the old create_graph_from_datastore() from neural_lam.create_graph to use the new wmg-based version from neural_lam.create_graph_with_wmg instead. Changes: - test_datasets.py: Use wmg create_graph_from_datastore with archetype='keisler' - test_clamping.py: Use wmg create_graph_from_datastore with archetype='keisler' - test_plotting.py: Use wmg create_graph_from_datastore with archetype='keisler' - test_training.py: Use wmg create_graph_from_datastore with archetype='keisler' - test_plot_graph.py: Use wmg create_graph_from_datastore with keisler and hierarchical archetypes. Removed multiscale (graphcast) parametrization since the graphcast archetype produces multi-level m2m edges without up/down edges, which is not yet compatible with utils.load_graph(). Graphcast graph creation is separately tested in test_graph_creation.py. --- tests/test_clamping.py | 4 ++-- tests/test_datasets.py | 4 ++-- tests/test_plot_graph.py | 28 +++++++++++++++------------- tests/test_plotting.py | 4 ++-- tests/test_training.py | 4 ++-- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/tests/test_clamping.py b/tests/test_clamping.py index f3f9365d0..b2f737a2c 100644 --- a/tests/test_clamping.py +++ b/tests/test_clamping.py @@ -6,7 +6,7 @@ # First-party from neural_lam import config as nlconfig -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore.mdp import MDPDatastore from neural_lam.models.graph_lam import GraphLAM from tests.conftest import init_datastore_example @@ -23,7 +23,7 @@ def test_clamping(): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) class ModelArgs: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dd863b657..ccdbf2c25 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -9,7 +9,7 @@ # First-party from neural_lam import config as nlconfig -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM @@ -194,7 +194,7 @@ def _create_graph(): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) if not isinstance(datastore, BaseRegularGridDatastore): diff --git a/tests/test_plot_graph.py b/tests/test_plot_graph.py index 559dac46a..d282aa039 100644 --- a/tests/test_plot_graph.py +++ b/tests/test_plot_graph.py @@ -8,19 +8,24 @@ # First-party from neural_lam import utils -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.plot_graph import ( plot_graph, ) from tests.dummy_datastore import DummyDatastore -@pytest.fixture(scope="module", params=["1level", "multiscale", "hierarchical"]) +@pytest.fixture(scope="module", params=["1level", "hierarchical"]) def graph_fixture(request, tmp_path_factory): """Create a graph from a DummyDatastore and load it back. - Parametrized over graph types: 1level (flat), multiscale (flat multi-level), - and hierarchical. + Parametrized over graph types: 1level (flat, keisler archetype) + and hierarchical (multi-level with up/down edges). + + Note: The graphcast archetype is not included here because it produces + multi-level m2m edges without up/down edges, which is not yet + compatible with ``utils.load_graph``. Graphcast graph creation is + tested separately in ``test_graph_creation.py``. Returns ------- @@ -31,14 +36,11 @@ def graph_fixture(request, tmp_path_factory): datastore = DummyDatastore() if graph_name == "hierarchical": - hierarchical = True - n_max_levels = 3 - elif graph_name == "multiscale": - hierarchical = False - n_max_levels = 3 + archetype = "hierarchical" + max_num_levels = 3 elif graph_name == "1level": - hierarchical = False - n_max_levels = 1 + archetype = "keisler" + max_num_levels = None else: raise ValueError(f"Unknown graph_name: {graph_name}") @@ -46,8 +48,8 @@ def graph_fixture(request, tmp_path_factory): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - hierarchical=hierarchical, - n_max_levels=n_max_levels, + archetype=archetype, + max_num_levels=max_num_levels, ) is_hierarchical, graph_ldict = utils.load_graph( diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ba03857c5..00f5a37a8 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -16,7 +16,7 @@ # First-party from neural_lam import config as nlconfig from neural_lam import vis -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset from tests.conftest import init_datastore_example @@ -226,7 +226,7 @@ class ModelArgs: create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) # Create config diff --git a/tests/test_training.py b/tests/test_training.py index d16bd9ed3..f1357eb51 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -10,7 +10,7 @@ # First-party from neural_lam import config as nlconfig -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.ar_model import ARModel @@ -68,7 +68,7 @@ def run_simple_training(datastore, set_output_std, metrics_watch=None): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) data_module = WeatherDataModule( From 8fe2e3e5e86947a21f20d537e1ff270747cc135e Mon Sep 17 00:00:00 2001 From: prajwal Date: Sat, 18 Apr 2026 11:35:01 +0530 Subject: [PATCH 8/8] Update wmg call: to_neural_lam -> to_torch_tensors_on_disk Follow rename in weather-model-graphs (mllam/weather-model-graphs#123). --- neural_lam/create_graph_with_wmg.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/neural_lam/create_graph_with_wmg.py b/neural_lam/create_graph_with_wmg.py index a5a2b12ca..632f3398a 100644 --- a/neural_lam/create_graph_with_wmg.py +++ b/neural_lam/create_graph_with_wmg.py @@ -89,8 +89,9 @@ def create_graph_from_datastore( mesh_node_distance = grid_spacing * grid_mesh_spacing_ratio # Build keyword arguments for the archetype function. - # return_components=True is required because wmg.save.to_neural_lam() - # expects the graph as separate g2m, m2g and m2m sub-graph components + # return_components=True is required because + # wmg.save.to_torch_tensors_on_disk() expects the graph as + # separate g2m, m2g and m2m sub-graph components # rather than a single merged graph. archetype_kwargs = dict( coords=xy, @@ -108,7 +109,7 @@ def create_graph_from_datastore( hierarchical = archetype == "hierarchical" - wmg.save.to_neural_lam( + wmg.save.to_torch_tensors_on_disk( graph_components=graph_components, output_directory=output_root_path, hierarchical=hierarchical,