15 changes: 14 additions & 1 deletion .github/workflows/install-and-test.yml
Member

I will make a separate PR for this. I think this is some general CI maintenance that should be merged in before this PR.

Author

Noted — happy to revert the CI changes from this branch if you'd prefer to handle them in a separate PR.

Original file line number Diff line number Diff line change
@@ -62,7 +62,20 @@ jobs:
- name: Determine PyTorch version for neural-lam (using pip dry-run)
if: matrix.package_manager == 'pip'
run: |
TORCH_VERSION=$(python -m pip install --dry-run "." | grep "Would install" | grep -o 'torch-[0-9.]*' | awk -F'-' '{print $2}' | tail -n 1)
# 1. Run pip install in dry-run (simulation) mode for the current directory (".").
#    --disable-pip-version-check suppresses the "[notice] A new release of pip is available" message.
#    2>/dev/null discards pip's stderr output (warnings and errors).
# 2. grep "Would install": keep only the summary line listing what pip would install.
# 3. grep -o 'torch-[0-9.]*': extract "torch-<version>". The version part may be empty,
#    so names such as torch-geometric also yield a bare "torch-".
# 4. awk -F'-' '{print $2}': strip the "torch-" prefix, leaving only the version number.
# 5. grep -v '^$': drop the empty lines produced by those bare "torch-" matches.
# 6. tail -n 1: keep only the last remaining match.
TORCH_VERSION=$(python -m pip install --disable-pip-version-check --dry-run "." 2>/dev/null \
| grep "Would install" \
| grep -o 'torch-[0-9.]*' \
| awk -F'-' '{print $2}' \
| grep -v '^$' \
| tail -n 1)
echo "Torch version detected: $TORCH_VERSION"
echo "TORCH_VERSION=$TORCH_VERSION" >> $GITHUB_ENV

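For comparison, the extraction done by the shell pipeline above can be sketched in Python with a stricter pattern, so that `torch-geometric` never produces an empty match in the first place. This is only an illustration, not part of the PR: the helper name `extract_torch_version` and the sample output are hypothetical, and the sketch assumes pip's dry-run summary is a single line starting with `Would install` followed by space-separated `name-version` tokens.

```python
import re

# Match a whole token of the form "torch-<version>". The version must start
# with a digit, so "torch-geometric-2.3.1" is rejected instead of producing
# an empty version string.
_TORCH_TOKEN = re.compile(r"^torch-(\d[\w.+]*)$")


def extract_torch_version(pip_output):
    """Return the torch version found in pip dry-run output, or None."""
    version = None
    for line in pip_output.splitlines():
        if not line.startswith("Would install"):
            continue
        # Skip the two leading words "Would" and "install".
        for token in line.split()[2:]:
            match = _TORCH_TOKEN.match(token)
            if match:
                version = match.group(1)
    return version


# Hypothetical sample of what the dry-run summary might look like.
sample = (
    "Collecting torch>=2.3.0\n"
    "Would install numpy-1.26.4 torch-2.3.0 torch-geometric-2.3.1"
)
print(extract_torch_version(sample))
```

Returning `None` when nothing matches makes a failed detection explicit, whereas the shell version silently exports an empty `TORCH_VERSION`.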
23 changes: 23 additions & 0 deletions README.md
@@ -390,6 +390,9 @@ python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats <path-

### Graph creation

> **Note:** The `create_graph` command below is deprecated and will be removed
> in a future release. Please use `create_graph_with_wmg` (see below) instead.

Run `python -m neural_lam.create_graph` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_graph --help` for a list of options).
The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as:

@@ -399,6 +402,26 @@ The graphs used for the different models in the [paper](#graph-based-neural-weat

The graph-related files are stored in a directory called `graphs`.

### Graph creation with weather-model-graphs

The recommended way to create graphs is with the `create_graph_with_wmg`
command, which delegates graph construction to
[weather-model-graphs](https://github.com/mllam/weather-model-graphs):

```bash
python -m neural_lam.create_graph_with_wmg --config_path <neural-lam-config-path> --archetype <archetype>
```

Available archetypes:

* **keisler** (default): `python -m neural_lam.create_graph_with_wmg --config_path <neural-lam-config-path> --archetype keisler`
* **graphcast**: `python -m neural_lam.create_graph_with_wmg --config_path <neural-lam-config-path> --archetype graphcast`
* **hierarchical**: `python -m neural_lam.create_graph_with_wmg --config_path <neural-lam-config-path> --archetype hierarchical`

Run `python -m neural_lam.create_graph_with_wmg --help` for the full list of
options (e.g. `--mesh_node_distance`, `--grid_mesh_ratio`,
`--level_refinement_factor`, `--max_num_levels`).

## Logging your experiments

### Weights & Biases Integration
9 changes: 9 additions & 0 deletions neural_lam/create_graph.py
@@ -1,5 +1,6 @@
# Standard library
import os
import warnings
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
from typing import Optional

Expand Down Expand Up @@ -561,6 +562,14 @@ def create_graph_from_datastore(


def cli(input_args=None):
    warnings.warn(
        "create_graph.py is deprecated and will be removed in a future "
        "version. Use create_graph_with_wmg.py instead, which delegates "
        "graph creation to weather-model-graphs (wmg). See "
        "https://github.com/mllam/neural-lam/issues/384 for details.",
        DeprecationWarning,
        stacklevel=2,
    )
    parser = ArgumentParser(
        description="Graph generation for neural-lam",
        formatter_class=ArgumentDefaultsHelpFormatter,
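A note on visibility of the warning added above: by default Python only displays a `DeprecationWarning` when it is attributed to code running in `__main__`, and `stacklevel=2` attributes it to the caller of `cli()` rather than to `cli()` itself. The sketch below uses a hypothetical stand-in function, `old_cli`, to show how a test could confirm that the warning is emitted regardless of the active filters.

```python
import warnings


def old_cli():
    """Hypothetical stand-in for a deprecated entry point."""
    warnings.warn(
        "old_cli is deprecated; use new_cli instead.",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller of old_cli()
    )
    return "graph created"


def call_and_collect():
    """Call old_cli() and return its result plus the warning categories."""
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning, overriding the default filters.
        warnings.simplefilter("always")
        result = old_cli()
    return result, [w.category for w in caught]


result, categories = call_and_collect()
print(result, categories)
```

With pytest, the same check is usually written with `pytest.warns(DeprecationWarning)` around the call.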
193 changes: 193 additions & 0 deletions neural_lam/create_graph_with_wmg.py
@@ -0,0 +1,193 @@
# Standard library
import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser

# Third-party
import numpy as np
import weather_model_graphs as wmg

# Local
from .config import load_config_and_datastore
from .datastore.base import BaseRegularGridDatastore

ARCHETYPE_FUNCTIONS = {
    "keisler": wmg.create.archetype.create_keisler_graph,
    "graphcast": wmg.create.archetype.create_graphcast_graph,
    "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph,
}


def _estimate_grid_node_spacing(xy):
    """Estimate the average grid node spacing from grid coordinates.

    Parameters
    ----------
    xy : np.ndarray
        Grid coordinates of shape ``(N, 2)``.

    Returns
    -------
    float
        Estimated average grid node spacing in coordinate units.
    """
    x_range = np.ptp(xy[:, 0])
    y_range = np.ptp(xy[:, 1])
    n_points = len(xy)
    # avg grid spacing ≈ sqrt(area / n_points)
    return float(np.sqrt(x_range * y_range / n_points))

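To make the estimate concrete, here is a pure-Python restatement of the same formula (bounding-box area divided by point count, then the square root) applied to a hypothetical 10 x 10 grid with unit spacing. The function and variable names are illustrative only. For an n x n grid with true spacing d, the bounding box spans (n - 1) * d per side, so the estimate is d * (n - 1) / n, slightly below the true spacing.

```python
import math


def estimate_grid_node_spacing(points):
    """Estimate spacing as sqrt(bounding-box area / number of points)."""
    xs = [x for x, _ in points]
    ys = [y for _, y in points]
    area = (max(xs) - min(xs)) * (max(ys) - min(ys))
    return math.sqrt(area / len(points))


# Hypothetical regular grid: 10 x 10 nodes with unit spacing.
grid = [(float(i), float(j)) for i in range(10) for j in range(10)]

spacing = estimate_grid_node_spacing(grid)
# Bounding box is 9 x 9 for 100 points, so the estimate is about 0.9.
mesh_node_distance = 3.0 * spacing  # default grid_mesh_ratio of 3.0
print(spacing, mesh_node_distance)
```

The underestimate shrinks as the grid grows, so for realistic domain sizes the default mesh node distance is close to three grid cells.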

def create_graph_from_datastore(
    datastore,
    output_root_path,
    archetype="keisler",
    mesh_node_distance=None,
    grid_mesh_ratio=3.0,
    level_refinement_factor=3,
    max_num_levels=None,
):
    """Create graph using weather-model-graphs and save in neural-lam format.

    Parameters
    ----------
    datastore : BaseRegularGridDatastore
        Datastore providing grid coordinates.
    output_root_path : str
        Directory where the .pt graph files will be saved.
    archetype : str
        Graph archetype to create: ``"keisler"``, ``"graphcast"``, or
        ``"hierarchical"``.
    mesh_node_distance : float or None
        Distance between created mesh nodes (in coordinate units). If None,
        automatically estimated as ``grid_mesh_ratio * grid_spacing``.
    grid_mesh_ratio : float
        Ratio of mesh node distance to grid node spacing. Only used when
        ``mesh_node_distance`` is None. Default is 3.0.
    level_refinement_factor : int
        Refinement factor between mesh hierarchy levels. Only used for
        ``"graphcast"`` and ``"hierarchical"`` archetypes.
    max_num_levels : int or None
        Maximum number of mesh hierarchy levels. Only used for ``"graphcast"``
        and ``"hierarchical"`` archetypes.
    """
    if not isinstance(datastore, BaseRegularGridDatastore):
        raise NotImplementedError(
            "Only graph creation for BaseRegularGridDatastore is supported"
        )

    if archetype not in ARCHETYPE_FUNCTIONS:
        raise ValueError(
            f"Unknown archetype '{archetype}'. "
            f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}"
        )

    xy = datastore.get_xy(category="state", stacked=True)
    xy = np.array(xy)

    if mesh_node_distance is None:
        grid_spacing = _estimate_grid_node_spacing(xy)
        mesh_node_distance = grid_spacing * grid_mesh_ratio

    # Build keyword arguments for the archetype function.
    # return_components=True is required because wmg.save.to_neural_lam()
    # expects the graph as separate g2m, m2g and m2m sub-graph components
    # rather than a single merged graph.
    archetype_kwargs = dict(
        coords=xy,
        mesh_node_distance=mesh_node_distance,
        return_components=True,
    )

    # Only multiscale/hierarchical archetypes accept these parameters
    if archetype in ("graphcast", "hierarchical"):
        archetype_kwargs["level_refinement_factor"] = level_refinement_factor
        archetype_kwargs["max_num_levels"] = max_num_levels

    archetype_fn = ARCHETYPE_FUNCTIONS[archetype]
    graph_components = archetype_fn(**archetype_kwargs)

    hierarchical = archetype == "hierarchical"

    wmg.save.to_neural_lam(
        graph_components=graph_components,
        output_directory=output_root_path,
        hierarchical=hierarchical,
    )

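The dispatch pattern used in `create_graph_from_datastore` (a name-to-function table plus keyword arguments that are only added for archetypes that accept them) can be exercised in isolation. The sketch below replaces the weather-model-graphs functions with hypothetical stubs, `create_flat` and `create_multilevel`, so their signatures are assumptions made for illustration and not the library's API.

```python
def create_flat(coords, mesh_node_distance, return_components):
    """Hypothetical stub for a single-level archetype function."""
    return {"kind": "flat", "mesh_node_distance": mesh_node_distance}


def create_multilevel(
    coords,
    mesh_node_distance,
    return_components,
    level_refinement_factor,
    max_num_levels,
):
    """Hypothetical stub for an archetype with several mesh levels."""
    return {
        "kind": "multilevel",
        "mesh_node_distance": mesh_node_distance,
        "level_refinement_factor": level_refinement_factor,
        "max_num_levels": max_num_levels,
    }


ARCHETYPES = {"flat": create_flat, "multilevel": create_multilevel}
MULTILEVEL_ARCHETYPES = {"multilevel"}


def build_graph(
    archetype,
    coords,
    mesh_node_distance,
    level_refinement_factor=3,
    max_num_levels=None,
):
    """Look up the archetype function and pass only the kwargs it accepts."""
    if archetype not in ARCHETYPES:
        raise ValueError(
            f"Unknown archetype '{archetype}'. "
            f"Must be one of: {sorted(ARCHETYPES)}"
        )
    kwargs = dict(
        coords=coords,
        mesh_node_distance=mesh_node_distance,
        return_components=True,
    )
    # The level parameters would be unexpected keywords for the flat stub.
    if archetype in MULTILEVEL_ARCHETYPES:
        kwargs["level_refinement_factor"] = level_refinement_factor
        kwargs["max_num_levels"] = max_num_levels
    return ARCHETYPES[archetype](**kwargs)


coords = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0), (1.0, 1.0)]
flat = build_graph("flat", coords, mesh_node_distance=2.7)
multi = build_graph("multilevel", coords, mesh_node_distance=2.7, max_num_levels=2)
print(flat, multi)
```

Keeping the table and the set of multi-level names side by side means a new archetype only needs one entry in each, and an unknown name fails early with a readable message.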

def cli(input_args=None):
    """Command-line interface for graph creation using weather-model-graphs."""
    parser = ArgumentParser(
        description="Graph generation for neural-lam using "
        "weather-model-graphs (wmg)",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to neural-lam configuration file",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="multiscale",
        help="Name to save graph as (used as subdirectory name)",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "graphcast", "hierarchical"],
        help="Graph archetype to create",
    )
    parser.add_argument(
        "--mesh_node_distance",
        type=float,
        default=None,
        help="Distance between mesh nodes (in coordinate units). "
        "If not set, estimated automatically from grid spacing "
        "and --grid_mesh_ratio.",
    )
    parser.add_argument(
        "--grid_mesh_ratio",
        type=float,
        default=3.0,
        help="Ratio of mesh node distance to grid node spacing. "
        "Only used when --mesh_node_distance is not set.",
    )
    parser.add_argument(
        "--level_refinement_factor",
        type=int,
        default=3,
        help="Refinement factor between mesh hierarchy levels "
        "(only used for graphcast and hierarchical)",
    )
    parser.add_argument(
        "--max_num_levels",
        type=int,
        default=None,
        help="Maximum number of mesh levels "
        "(only used for graphcast and hierarchical)",
    )
    args = parser.parse_args(input_args)

    assert (
        args.config_path is not None
    ), "Specify your config with --config_path"

    # Load neural-lam configuration and datastore to use
    _, datastore = load_config_and_datastore(config_path=args.config_path)

    create_graph_from_datastore(
        datastore=datastore,
        output_root_path=os.path.join(datastore.root_path, "graph", args.name),
        archetype=args.archetype,
        mesh_node_distance=args.mesh_node_distance,
        grid_mesh_ratio=args.grid_mesh_ratio,
        level_refinement_factor=args.level_refinement_factor,
        max_num_levels=args.max_num_levels,
    )

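One possible alternative to the `assert` on `args.config_path` in `cli()` above: `assert` statements are skipped when Python runs with `-O`, whereas argparse can enforce the argument itself via `required=True` and print a usage message on failure. The sketch below is a reduced, hypothetical parser (`build_parser` is not part of the PR) that keeps only two of the options to show the behaviour.

```python
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser


def build_parser():
    """Build a reduced parser with --config_path marked as required."""
    parser = ArgumentParser(
        description="Graph generation sketch",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config_path",
        type=str,
        required=True,  # argparse exits with a usage error if missing
        help="Path to neural-lam configuration file",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        choices=["keisler", "graphcast", "hierarchical"],
        help="Graph archetype to create",
    )
    return parser


# "config.yaml" is a placeholder path; the parser does not open the file.
args = build_parser().parse_args(
    ["--config_path", "config.yaml", "--archetype", "graphcast"]
)
print(args.config_path, args.archetype)
```

Whether to switch is a style decision for the maintainers; the existing `create_graph.py` may use the same `assert` pattern, in which case consistency is an argument for keeping it.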

if __name__ == "__main__":
    cli()
9 changes: 8 additions & 1 deletion pyproject.toml
@@ -30,17 +30,21 @@ dependencies = [
"matplotlib>=3.7.0",
"plotly>=5.15.0",
"torch>=2.3.0",
"torch-geometric==2.3.1",
"torch-geometric>=2.5.3",
Member

is this a necessary change?

Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Yes, this is necessary. The weather-model-graphs[pytorch] extra requires torch-geometric>=2.5.3 in its own dependencies, so we need to allow at least that version here too.

Member

Ah, in that case I think we should maybe lower the required version on the weather-model-graphs side. All we use torch-geometric for is converting the networkx.DiGraph objects to torch.Tensor objects, and the current create_graph.py code in neural-lam already does that with the older version. So how about we change weather-model-graphs to require torch-geometric==2.3.1 so the two are in sync? Then we don't have to change anything on the neural-lam side.

Author

You're right — wmg only uses from_networkx which works fine with 2.3.1. I've also pushed a commit to the wmg PR branch (prajwal-tech07/weather-model-graphs@d0c693d) pinning it to ==2.3.1 there too, so both repos stay in sync.

"parse>=1.20.2",
"dataclass-wizard<0.31.0",
"mllam-data-prep>=0.5.0",
"mlflow>=2.16.2",
"boto3>=1.35.32",
"nvidia-ml-py>=13.580.82",
"pillow>=9.0.0",
"weather-model-graphs",
]
requires-python = ">=3.10"

[project.scripts]
create_graph_with_wmg = "neural_lam.create_graph_with_wmg:cli"

[dependency-groups]
dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"]

@@ -130,3 +134,6 @@ exclude = [
".venv/",
"venv/",
]

[tool.uv.sources]
weather-model-graphs = { git = "https://github.com/prajwal-tech07/weather-model-graphs", rev = "issue-384/to-neural-lam" }