Skip to content
15 changes: 14 additions & 1 deletion .github/workflows/install-and-test.yml
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will make a separate PR for this. I think this is some general CI maintenance that should be merged in before this PR.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted — happy to revert the CI changes from this branch if you'd prefer to handle them in a separate PR.

Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,20 @@ jobs:
- name: Determine PyTorch version for neural-lam (using pip dry-run)
if: matrix.package_manager == 'pip'
run: |
TORCH_VERSION=$(python -m pip install --dry-run "." | grep "Would install" | grep -o 'torch-[0-9.]*' | awk -F'-' '{print $2}' | tail -n 1)
# 1. Run pip install in simulation mode for the current directory (".")
# --disable-pip-version-check stops the "[notice] A new release of pip is available" message.
# 2>/dev/null silences any other unexpected errors or warnings by throwing them away.
# 2. grep "Would install": Filter the output to only look at the summary of what pip intends to do.
# 3. grep -o 'torch-[0-9.]*': Extract just the package name and version for torch.
# 4. awk -F'-' '{print $2}': Strip the "torch-" prefix to get just the version number.
# 5. grep -v '^$': Remove any empty lines that might have been produced.
# 6. tail -n 1: Grab only the final result.
TORCH_VERSION=$(python -m pip install --disable-pip-version-check --dry-run "." 2>/dev/null \
| grep "Would install" \
| grep -o 'torch-[0-9.]*' \
| awk -F'-' '{print $2}' \
| grep -v '^$' \
| tail -n 1)
echo "Torch version detected: $TORCH_VERSION"
echo "TORCH_VERSION=$TORCH_VERSION" >> $GITHUB_ENV

Expand Down
9 changes: 9 additions & 0 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Standard library
import os
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from typing import Optional

Expand Down Expand Up @@ -561,6 +562,14 @@ def create_graph_from_datastore(


def cli(input_args=None):
warnings.warn(
"create_graph.py is deprecated and will be removed in a future "
"version. Use create_graph_with_wmg.py instead, which delegates "
"graph creation to weather-model-graphs (wmg). See "
"https://github.com/mllam/neural-lam/issues/384 for details.",
DeprecationWarning,
stacklevel=2,
)
parser = ArgumentParser(
description="Graph generation for neural-lam",
formatter_class=ArgumentDefaultsHelpFormatter,
Expand Down
188 changes: 188 additions & 0 deletions neural_lam/create_graph_with_wmg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Standard library
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

# Third-party
import numpy as np
import weather_model_graphs as wmg

# Local
from .config import load_config_and_datastore
from .datastore.base import BaseRegularGridDatastore

ARCHETYPE_FUNCTIONS = {
"keisler": wmg.create.archetype.create_keisler_graph,
"graphcast": wmg.create.archetype.create_graphcast_graph,
"hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph,
}


def _estimate_mesh_node_distance(xy):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we could call this _estimate_grid_node_spacing? And then expose the grid-mesh-spacing ratio as a CLI arg that defaults to 3.0?

Copy link
Copy Markdown
Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed _estimate_mesh_node_distance → _estimate_grid_node_spacing — it now returns only the average grid spacing. The ×3 multiplier is replaced by a new grid_mesh_ratio parameter (default 3.0) exposed as --grid_mesh_ratio on the CLI.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks great, but it would better to name it --grid_mesh_spacing_ratio I think - just be clear about what it is the ratio between :) what do you think?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

--grid_mesh_ratio--grid_mesh_spacing_ratio
Good call — renamed the CLI arg, function parameter, and docstrings throughout to make it clear it's the ratio between mesh-node and grid-node spacing.

"""Estimate a reasonable mesh node distance from grid coordinates.

Uses the average grid spacing to produce a mesh that is roughly 3x
coarser than the grid, similar to the default behaviour of the old
``create_graph.py`` script.

Parameters
----------
xy : np.ndarray
Grid coordinates of shape ``(N, 2)``.

Returns
-------
float
Estimated mesh node distance in coordinate units.
"""
x_range = np.ptp(xy[:, 0])
y_range = np.ptp(xy[:, 1])
n_points = len(xy)
# avg grid spacing ≈ sqrt(area / n_points)
avg_spacing = np.sqrt(x_range * y_range / n_points)
# mesh is ~3x coarser than the grid
return float(avg_spacing * 3)


def create_graph_from_datastore(
datastore,
output_root_path,
archetype="keisler",
mesh_node_distance=None,
level_refinement_factor=3,
max_num_levels=None,
):
"""Create graph using weather-model-graphs and save in neural-lam format.

Parameters
----------
datastore : BaseRegularGridDatastore
Datastore providing grid coordinates.
output_root_path : str
Directory where the .pt graph files will be saved.
archetype : str
Graph archetype to create: ``"keisler"``, ``"graphcast"``, or
``"hierarchical"``.
mesh_node_distance : float or None
Distance between created mesh nodes (in coordinate units). If None,
automatically estimated from the grid spacing.
level_refinement_factor : int
Refinement factor between mesh hierarchy levels. Only used for
``"graphcast"`` and ``"hierarchical"`` archetypes.
max_num_levels : int or None
Maximum number of mesh hierarchy levels. Only used for ``"graphcast"``
and ``"hierarchical"`` archetypes.
"""
if not isinstance(datastore, BaseRegularGridDatastore):
raise NotImplementedError(
"Only graph creation for BaseRegularGridDatastore is supported"
)

if archetype not in ARCHETYPE_FUNCTIONS:
raise ValueError(
f"Unknown archetype '{archetype}'. "
f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}"
)

xy = datastore.get_xy(category="state", stacked=False)

Comment thread
leifdenby marked this conversation as resolved.
Outdated
# wmg expects coords as 2D array of shape (num_nodes, 2), but the
# datastore may return a 3D array of shape (Nx, Ny, 2) when
# stacked=False. Reshape to (N, 2) for wmg.
xy = np.array(xy)
if xy.ndim == 3:
xy = xy.reshape(-1, 2)

if mesh_node_distance is None:
mesh_node_distance = _estimate_mesh_node_distance(xy)

# Build keyword arguments for the archetype function
archetype_kwargs = dict(
coords=xy,
mesh_node_distance=mesh_node_distance,
return_components=True,
Comment thread
leifdenby marked this conversation as resolved.
)

# Only multiscale/hierarchical archetypes accept these parameters
if archetype in ("graphcast", "hierarchical"):
archetype_kwargs["level_refinement_factor"] = level_refinement_factor
archetype_kwargs["max_num_levels"] = max_num_levels

archetype_fn = ARCHETYPE_FUNCTIONS[archetype]
graph_components = archetype_fn(**archetype_kwargs)

hierarchical = archetype == "hierarchical"

wmg.save.to_neural_lam(
graph_components=graph_components,
output_directory=output_root_path,
hierarchical=hierarchical,
)


def cli(input_args=None):
"""Command-line interface for graph creation using weather-model-graphs."""
parser = ArgumentParser(
description="Graph generation for neural-lam using "
"weather-model-graphs (wmg)",
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--config_path",
type=str,
help="Path to neural-lam configuration file",
)
parser.add_argument(
"--name",
type=str,
default="multiscale",
help="Name to save graph as (used as subdirectory name)",
)
parser.add_argument(
"--archetype",
type=str,
default="keisler",
choices=["keisler", "graphcast", "hierarchical"],
help="Graph archetype to create",
)
parser.add_argument(
"--mesh_node_distance",
type=float,
default=None,
help="Distance between mesh nodes (in coordinate units). "
"If not set, estimated automatically from grid spacing.",
)
parser.add_argument(
"--level_refinement_factor",
type=int,
default=3,
help="Refinement factor between mesh hierarchy levels "
"(only used for graphcast and hierarchical)",
)
parser.add_argument(
"--max_num_levels",
type=int,
default=None,
help="Maximum number of mesh levels "
"(only used for graphcast and hierarchical)",
)
args = parser.parse_args(input_args)

assert (
args.config_path is not None
), "Specify your config with --config_path"

# Load neural-lam configuration and datastore to use
_, datastore = load_config_and_datastore(config_path=args.config_path)

create_graph_from_datastore(
datastore=datastore,
output_root_path=os.path.join(datastore.root_path, "graph", args.name),
archetype=args.archetype,
mesh_node_distance=args.mesh_node_distance,
level_refinement_factor=args.level_refinement_factor,
max_num_levels=args.max_num_levels,
)


if __name__ == "__main__":
cli()
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,21 @@ dependencies = [
"matplotlib>=3.7.0",
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
"torch-geometric>=2.5.3",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a necessary change?

Copy link
Copy Markdown
Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is necessary. The weather-model-graphs[pytorch] extra requires torch-geometric>=2.5.3 in its own dependencies, so we need to allow at least that version here too.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, in that case I think we maybe should reduce required version number on the weather-model-graphs side, since all we are doing with pytorch-geometric is using it to convert the networkx.DiGraph objects to torch.Tensor objects, and we already do that in the current create_graphs.py code in neural-lam with the older version. So how about we change weather-model-graphs to require torch-geometric==2.3.1 so the two are in sync? Then we don't have to change anything on the neural-lam side

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right — wmg only uses from_networkx which works fine with 2.3.1. I've also pushed a commit to the wmg PR branch (prajwal-tech07/weather-model-graphs@d0c693d) pinning it to ==2.3.1 there too, so both repos stay in sync.

"parse>=1.20.2",
"dataclass-wizard<0.31.0",
"mllam-data-prep>=0.5.0",
"mlflow>=2.16.2",
"boto3>=1.35.32",
"nvidia-ml-py>=13.580.82",
"pillow>=9.0.0",
"weather-model-graphs",
]
requires-python = ">=3.10"

[project.scripts]
create_graph_with_wmg = "neural_lam.create_graph_with_wmg:cli"
Comment thread
leifdenby marked this conversation as resolved.

[dependency-groups]
dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"]

Expand Down Expand Up @@ -130,3 +134,6 @@ exclude = [
".venv/",
"venv/",
]

[tool.uv.sources]
weather-model-graphs = { git = "https://github.com/prajwal-tech07/weather-model-graphs", rev = "issue-384/to-neural-lam" }
103 changes: 103 additions & 0 deletions tests/test_graph_creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Standard library
import tempfile
import warnings
from pathlib import Path

# Third-party
Expand All @@ -8,6 +9,9 @@

# First-party
from neural_lam.create_graph import create_graph_from_datastore
from neural_lam.create_graph_with_wmg import (
create_graph_from_datastore as wmg_create_graph_from_datastore,
)
from neural_lam.datastore import DATASTORES
from neural_lam.datastore.base import BaseRegularGridDatastore
from tests.conftest import init_datastore_example
Expand Down Expand Up @@ -117,3 +121,102 @@ def test_graph_creation(datastore_name, graph_name):
assert r.shape[0] == 2 # adjacency matrix uses two rows
elif file_id.endswith("_features"):
assert r.shape[1] == d_features


@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_wmg_graph_creation(datastore_name, archetype):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great, but we could actually use the testing that #323 introduces. I should try and get that finished so that we can merge both in together :)

Copy link
Copy Markdown
Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, the tests are working well as-is! Happy to adapt them to the testing infrastructure from #323 once that's ready — makes total sense to merge them together. Let me know if there's anything I can do to help with #323!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, how about we say that depending on whether I finish #323 soon then we either a) merge this in with the tests have you have them implemented already or b) merge #323 in first and then adapting the testing here to use the testing being introduced in #323?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like a great plan — either way works for me. Happy to adapt the tests to #323's infrastructure once it lands, or merge as-is if this PR is ready

"""Check that graph creation via weather-model-graphs produces the
expected .pt files with correct shapes and types."""
datastore = init_datastore_example(datastore_name)

if not isinstance(datastore, BaseRegularGridDatastore):
pytest.skip(
f"Skipping test for {datastore_name} as it is not a regular "
"grid datastore."
)

hierarchical = archetype == "hierarchical"

required_graph_files = [
"m2m_edge_index.pt",
"g2m_edge_index.pt",
"m2g_edge_index.pt",
"m2m_features.pt",
"g2m_features.pt",
"m2g_features.pt",
"mesh_features.pt",
]
if hierarchical:
required_graph_files.extend(
[
"mesh_up_edge_index.pt",
"mesh_down_edge_index.pt",
"mesh_up_features.pt",
"mesh_down_features.pt",
]
)

d_features = 3
d_mesh_static = 2

with tempfile.TemporaryDirectory() as tmpdir:
graph_dir_path = Path(tmpdir) / "graph" / archetype

wmg_create_graph_from_datastore(
datastore=datastore,
output_root_path=str(graph_dir_path),
archetype=archetype,
)

assert graph_dir_path.exists()

# check that all the required files are present
for file_name in required_graph_files:
assert (graph_dir_path / file_name).exists()

# try to load each and ensure they have the right shape
for file_name in required_graph_files:
file_id = Path(file_name).stem
result = torch.load(graph_dir_path / file_name, weights_only=True)

if file_id.startswith("g2m") or file_id.startswith("m2g"):
assert isinstance(result, torch.Tensor)

if file_id.endswith("_index"):
assert result.shape[0] == 2
elif file_id.endswith("_features"):
assert result.shape[1] == d_features

elif file_id.startswith("m2m") or file_id.startswith("mesh"):
assert isinstance(result, list)

for r in result:
assert isinstance(r, torch.Tensor)

if file_id == "mesh_features":
assert r.shape[1] == d_mesh_static
elif file_id.endswith("_index"):
assert r.shape[0] == 2
elif file_id.endswith("_features"):
assert r.shape[1] == d_features


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_old_create_graph_deprecation_warning(datastore_name):
"""Check that the old create_graph CLI emits a deprecation warning."""
# First-party
from neural_lam.create_graph import cli as old_cli

with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
try:
old_cli(["--config_path", "nonexistent.yaml"])
except Exception:
pass # We only care about the warning, not the error

deprecation_warnings = [
x for x in w if issubclass(x.category, DeprecationWarning)
]
assert len(deprecation_warnings) >= 1
assert "create_graph_with_wmg" in str(deprecation_warnings[0].message)
Loading