diff --git a/.github/workflows/install-and-test.yml b/.github/workflows/install-and-test.yml index 8e6be806f..8cb9b3c08 100644 --- a/.github/workflows/install-and-test.yml +++ b/.github/workflows/install-and-test.yml @@ -62,7 +62,20 @@ jobs: - name: Determine PyTorch version for neural-lam (using pip dry-run) if: matrix.package_manager == 'pip' run: | - TORCH_VERSION=$(python -m pip install --dry-run "." | grep "Would install" | grep -o 'torch-[0-9.]*' | awk -F'-' '{print $2}' | tail -n 1) + # 1. Run pip install in simulation mode for the current directory (".") + # --disable-pip-version-check stops the "[notice] A new release of pip is available" message. + # 2>/dev/null silences any other unexpected errors or warnings by throwing them away. + # 2. grep "Would install": Filter the output to only look at the summary of what pip intends to do. + # 3. grep -o 'torch-[0-9.]*': Extract just the package name and version for torch. + # 4. awk -F'-' '{print $2}': Strip the "torch-" prefix to get just the version number. + # 5. grep -v '^$': Remove any empty lines that might have been produced. + # 6. tail -n 1: Grab only the final result. + TORCH_VERSION=$(python -m pip install --disable-pip-version-check --dry-run "." 2>/dev/null \ + | grep "Would install" \ + | grep -o 'torch-[0-9.]*' \ + | awk -F'-' '{print $2}' \ + | grep -v '^$' \ + | tail -n 1) echo "Torch version detected: $TORCH_VERSION" echo "TORCH_VERSION=$TORCH_VERSION" >> $GITHUB_ENV diff --git a/README.md b/README.md index 17e269841..c4fa0b961 100644 --- a/README.md +++ b/README.md @@ -390,6 +390,9 @@ python -m neural_lam.datastore.npyfilesmeps.compute_standardization_stats **Note:** The `create_graph` command below is deprecated and will be removed +> in a future release. Please use `create_graph_with_wmg` (see below) instead. + Run `python -m neural_lam.create_graph` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_graph --help` for a list of options). 
The graphs used for the different models in the [paper](#graph-based-neural-weather-prediction-for-limited-area-modeling) can be created as: @@ -399,6 +402,26 @@ The graphs used for the different models in the [paper](#graph-based-neural-weat The graph-related files are stored in a directory called `graphs`. +### Graph creation with weather-model-graphs + +The recommended way to create graphs is with the `create_graph_with_wmg` +command, which delegates graph construction to +[weather-model-graphs](https://github.com/mllam/weather-model-graphs): + +```bash +python -m neural_lam.create_graph_with_wmg --config_path CONFIG_PATH --archetype ARCHETYPE +``` + +Available archetypes: + +* **keisler** (default): `python -m neural_lam.create_graph_with_wmg --config_path CONFIG_PATH --archetype keisler` +* **graphcast**: `python -m neural_lam.create_graph_with_wmg --config_path CONFIG_PATH --archetype graphcast` +* **hierarchical**: `python -m neural_lam.create_graph_with_wmg --config_path CONFIG_PATH --archetype hierarchical` + +Run `python -m neural_lam.create_graph_with_wmg --help` for the full list of +options (e.g. `--mesh_node_distance`, `--grid_mesh_spacing_ratio`, +`--level_refinement_factor`, `--max_num_levels`). + ## Logging your experiments ### Weights & Biases Integration diff --git a/neural_lam/create_graph.py b/neural_lam/create_graph.py index b24b96c23..03a9efcf7 100644 --- a/neural_lam/create_graph.py +++ b/neural_lam/create_graph.py @@ -1,5 +1,6 @@ # Standard library import os +import warnings from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from typing import Optional @@ -561,6 +562,14 @@ def create_graph_from_datastore( def cli(input_args=None): + warnings.warn( + "create_graph.py is deprecated and will be removed in a future " + "version. Use create_graph_with_wmg.py instead, which delegates " + "graph creation to weather-model-graphs (wmg). 
def _estimate_grid_node_spacing(xy):
    """Estimate the average grid node spacing from grid coordinates.

    The spacing is approximated as ``sqrt(area / n_points)``, where ``area``
    is the area of the axis-aligned bounding box of the coordinates.

    Parameters
    ----------
    xy : np.ndarray
        Grid coordinates of shape ``(N, 2)``.

    Returns
    -------
    float
        Estimated average grid node spacing in coordinate units.

    Raises
    ------
    ValueError
        If ``xy`` contains no points, or if no positive, finite spacing can
        be derived from it (e.g. all points share the same x or the same y
        coordinate, or the coordinates contain NaN).
    """
    xy = np.asarray(xy)
    n_points = len(xy)
    if n_points == 0:
        raise ValueError(
            "Cannot estimate grid node spacing from an empty coordinate array"
        )

    x_range = np.ptp(xy[:, 0])
    y_range = np.ptp(xy[:, 1])
    # avg grid spacing ≈ sqrt(area / n_points)
    spacing = float(np.sqrt(x_range * y_range / n_points))

    # A zero-area bounding box would yield a spacing of 0 and, further down,
    # a mesh node distance of 0. Fail here with a clear message instead.
    if not np.isfinite(spacing) or spacing <= 0.0:
        raise ValueError(
            "Cannot estimate grid node spacing: the grid coordinates must "
            "span a positive, finite extent in both x and y "
            f"(got x-range {x_range}, y-range {y_range})"
        )
    return spacing


def create_graph_from_datastore(
    datastore,
    output_root_path,
    archetype="keisler",
    mesh_node_distance=None,
    grid_mesh_spacing_ratio=3.0,
    level_refinement_factor=3,
    max_num_levels=None,
):
    """Create graph using weather-model-graphs and save in neural-lam format.

    Parameters
    ----------
    datastore : BaseRegularGridDatastore
        Datastore providing grid coordinates.
    output_root_path : str
        Directory where the .pt graph files will be saved.
    archetype : str
        Graph archetype to create: ``"keisler"``, ``"graphcast"``, or
        ``"hierarchical"``.
    mesh_node_distance : float or None
        Distance between created mesh nodes (in coordinate units). If None,
        automatically estimated as ``grid_mesh_spacing_ratio * grid_spacing``.
    grid_mesh_spacing_ratio : float
        Ratio of mesh node distance to grid node spacing. Only used when
        ``mesh_node_distance`` is None. Default is 3.0.
    level_refinement_factor : int
        Refinement factor between mesh hierarchy levels. Only used for
        ``"graphcast"`` and ``"hierarchical"`` archetypes.
    max_num_levels : int or None
        Maximum number of mesh hierarchy levels. Only used for ``"graphcast"``
        and ``"hierarchical"`` archetypes.

    Raises
    ------
    NotImplementedError
        If ``datastore`` is not a ``BaseRegularGridDatastore``.
    ValueError
        If ``archetype`` is unknown, if ``mesh_node_distance`` is given but
        not positive, if ``grid_mesh_spacing_ratio`` is needed but not
        positive, or if the grid spacing cannot be estimated.
    """
    if not isinstance(datastore, BaseRegularGridDatastore):
        raise NotImplementedError(
            "Only graph creation for BaseRegularGridDatastore is supported"
        )

    if archetype not in ARCHETYPE_FUNCTIONS:
        raise ValueError(
            f"Unknown archetype '{archetype}'. "
            f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}"
        )

    # Reject distances that cannot describe a mesh before doing any work.
    if mesh_node_distance is not None:
        if mesh_node_distance <= 0:
            raise ValueError(
                f"mesh_node_distance must be positive, got {mesh_node_distance}"
            )
    elif grid_mesh_spacing_ratio <= 0:
        raise ValueError(
            "grid_mesh_spacing_ratio must be positive, got "
            f"{grid_mesh_spacing_ratio}"
        )

    xy = datastore.get_xy(category="state", stacked=True)
    xy = np.array(xy)

    if mesh_node_distance is None:
        grid_spacing = _estimate_grid_node_spacing(xy)
        mesh_node_distance = grid_spacing * grid_mesh_spacing_ratio

    # Build keyword arguments for the archetype function.
    # return_components=True is required because
    # wmg.save.to_torch_tensors_on_disk() expects the graph as
    # separate g2m, m2g and m2m sub-graph components
    # rather than a single merged graph.
    archetype_kwargs = {
        "coords": xy,
        "mesh_node_distance": mesh_node_distance,
        "return_components": True,
    }

    # Only multiscale/hierarchical archetypes accept these parameters
    if archetype in ("graphcast", "hierarchical"):
        archetype_kwargs["level_refinement_factor"] = level_refinement_factor
        archetype_kwargs["max_num_levels"] = max_num_levels

    archetype_fn = ARCHETYPE_FUNCTIONS[archetype]
    graph_components = archetype_fn(**archetype_kwargs)

    wmg.save.to_torch_tensors_on_disk(
        graph_components=graph_components,
        output_directory=output_root_path,
        hierarchical=(archetype == "hierarchical"),
    )


def cli(input_args=None):
    """Command-line interface for graph creation using weather-model-graphs.

    The graph is written to ``<datastore.root_path>/graph/<name>``, where the
    datastore is the one loaded from ``--config_path``.

    Parameters
    ----------
    input_args : list of str or None
        Arguments to parse, forwarded to ``ArgumentParser.parse_args``.
        With None, argparse reads the arguments from ``sys.argv``.

    Raises
    ------
    SystemExit
        If the arguments are invalid or ``--config_path`` is missing.
    """
    parser = ArgumentParser(
        description="Graph generation for neural-lam using "
        "weather-model-graphs (wmg)",
        formatter_class=ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to neural-lam configuration file",
    )
    parser.add_argument(
        "--name",
        type=str,
        default="multiscale",
        help="Name to save graph as (used as subdirectory name)",
    )
    parser.add_argument(
        "--archetype",
        type=str,
        default="keisler",
        # Derived from the registry so the CLI cannot drift out of sync
        # with the archetypes create_graph_from_datastore accepts.
        choices=list(ARCHETYPE_FUNCTIONS),
        help="Graph archetype to create",
    )
    parser.add_argument(
        "--mesh_node_distance",
        type=float,
        default=None,
        help="Distance between mesh nodes (in coordinate units). "
        "If not set, estimated automatically from grid spacing "
        "and --grid_mesh_spacing_ratio.",
    )
    parser.add_argument(
        "--grid_mesh_spacing_ratio",
        type=float,
        default=3.0,
        help="Ratio of mesh node distance to grid node spacing. "
        "Only used when --mesh_node_distance is not set.",
    )
    parser.add_argument(
        "--level_refinement_factor",
        type=int,
        default=3,
        help="Refinement factor between mesh hierarchy levels "
        "(only used for graphcast and hierarchical)",
    )
    parser.add_argument(
        "--max_num_levels",
        type=int,
        default=None,
        help="Maximum number of mesh levels "
        "(only used for graphcast and hierarchical)",
    )
    args = parser.parse_args(input_args)

    # Explicit check rather than ``assert``: assertions are removed when
    # Python runs with -O, which would let a missing path slip through.
    if args.config_path is None:
        parser.error("Specify your config with --config_path")

    # Load neural-lam configuration and datastore to use
    _, datastore = load_config_and_datastore(config_path=args.config_path)

    create_graph_from_datastore(
        datastore=datastore,
        output_root_path=os.path.join(datastore.root_path, "graph", args.name),
        archetype=args.archetype,
        mesh_node_distance=args.mesh_node_distance,
        grid_mesh_spacing_ratio=args.grid_mesh_spacing_ratio,
        level_refinement_factor=args.level_refinement_factor,
        max_num_levels=args.max_num_levels,
    )


if __name__ == "__main__":
    cli()
+from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore.mdp import MDPDatastore from neural_lam.models.graph_lam import GraphLAM from tests.conftest import init_datastore_example @@ -23,7 +23,7 @@ def test_clamping(): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) class ModelArgs: diff --git a/tests/test_datasets.py b/tests/test_datasets.py index dd863b657..ccdbf2c25 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -9,7 +9,7 @@ # First-party from neural_lam import config as nlconfig -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.graph_lam import GraphLAM @@ -194,7 +194,7 @@ def _create_graph(): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) if not isinstance(datastore, BaseRegularGridDatastore): diff --git a/tests/test_graph_creation.py b/tests/test_graph_creation.py index 93a7a55f4..76a4f7bdc 100644 --- a/tests/test_graph_creation.py +++ b/tests/test_graph_creation.py @@ -1,5 +1,6 @@ # Standard library import tempfile +import warnings from pathlib import Path # Third-party @@ -8,6 +9,9 @@ # First-party from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import ( + create_graph_from_datastore as wmg_create_graph_from_datastore, +) from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from tests.conftest import init_datastore_example @@ -117,3 +121,102 @@ def test_graph_creation(datastore_name, graph_name): assert r.shape[0] == 2 # adjacency matrix uses two rows elif file_id.endswith("_features"): assert r.shape[1] == 
@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_wmg_graph_creation(datastore_name, archetype):
    """Check that graph creation via weather-model-graphs produces the
    expected .pt files with correct shapes and types.

    Datastores that are not regular-grid datastores are skipped. The
    ``hierarchical`` archetype must additionally produce the mesh up/down
    edge files.
    """
    datastore = init_datastore_example(datastore_name)

    if not isinstance(datastore, BaseRegularGridDatastore):
        pytest.skip(
            f"Skipping test for {datastore_name} as it is not a regular "
            "grid datastore."
        )

    required_graph_files = [
        "m2m_edge_index.pt",
        "g2m_edge_index.pt",
        "m2g_edge_index.pt",
        "m2m_features.pt",
        "g2m_features.pt",
        "m2g_features.pt",
        "mesh_features.pt",
    ]
    if archetype == "hierarchical":
        required_graph_files += [
            "mesh_up_edge_index.pt",
            "mesh_down_edge_index.pt",
            "mesh_up_features.pt",
            "mesh_down_features.pt",
        ]

    # Expected number of columns in edge feature and static mesh feature
    # tensors, respectively.
    d_features = 3
    d_mesh_static = 2

    with tempfile.TemporaryDirectory() as tmpdir:
        graph_dir_path = Path(tmpdir) / "graph" / archetype

        wmg_create_graph_from_datastore(
            datastore=datastore,
            output_root_path=str(graph_dir_path),
            archetype=archetype,
        )

        assert graph_dir_path.exists()

        # check that all the required files are present
        for file_name in required_graph_files:
            assert (graph_dir_path / file_name).exists()

        # try to load each and ensure they have the right shape
        for file_name in required_graph_files:
            file_id = Path(file_name).stem
            result = torch.load(graph_dir_path / file_name, weights_only=True)

            if file_id.startswith(("g2m", "m2g")):
                # grid<->mesh files each hold a single tensor
                assert isinstance(result, torch.Tensor)

                if file_id.endswith("_index"):
                    assert result.shape[0] == 2
                elif file_id.endswith("_features"):
                    assert result.shape[1] == d_features

            elif file_id.startswith(("m2m", "mesh")):
                # mesh files hold a list of tensors
                assert isinstance(result, list)

                for r in result:
                    assert isinstance(r, torch.Tensor)

                    if file_id == "mesh_features":
                        assert r.shape[1] == d_mesh_static
                    elif file_id.endswith("_index"):
                        assert r.shape[0] == 2
                    elif file_id.endswith("_features"):
                        assert r.shape[1] == d_features


def test_old_create_graph_deprecation_warning():
    """Check that the old create_graph CLI emits a deprecation warning.

    The test does not depend on any datastore, so it is not parametrized:
    it only inspects the warnings recorded while calling the old CLI with a
    config path that does not exist.
    """
    # First-party
    from neural_lam.create_graph import cli as old_cli

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        try:
            old_cli(["--config_path", "nonexistent.yaml"])
        except Exception:
            pass  # We only care about the warning, not the error

    deprecation_messages = [
        str(x.message) for x in w if issubclass(x.category, DeprecationWarning)
    ]
    assert len(deprecation_messages) >= 1
    # Accept any matching warning, so an unrelated DeprecationWarning that
    # happens to be recorded first cannot cause a false failure.
    assert any("create_graph_with_wmg" in msg for msg in deprecation_messages)
+ + Note: The graphcast archetype is not included here because it produces + multi-level m2m edges without up/down edges, which is not yet + compatible with ``utils.load_graph``. Graphcast graph creation is + tested separately in ``test_graph_creation.py``. Returns ------- @@ -31,14 +36,11 @@ def graph_fixture(request, tmp_path_factory): datastore = DummyDatastore() if graph_name == "hierarchical": - hierarchical = True - n_max_levels = 3 - elif graph_name == "multiscale": - hierarchical = False - n_max_levels = 3 + archetype = "hierarchical" + max_num_levels = 3 elif graph_name == "1level": - hierarchical = False - n_max_levels = 1 + archetype = "keisler" + max_num_levels = None else: raise ValueError(f"Unknown graph_name: {graph_name}") @@ -46,8 +48,8 @@ def graph_fixture(request, tmp_path_factory): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - hierarchical=hierarchical, - n_max_levels=n_max_levels, + archetype=archetype, + max_num_levels=max_num_levels, ) is_hierarchical, graph_ldict = utils.load_graph( diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ba03857c5..00f5a37a8 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -16,7 +16,7 @@ # First-party from neural_lam import config as nlconfig from neural_lam import vis -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.models.graph_lam import GraphLAM from neural_lam.weather_dataset import WeatherDataset from tests.conftest import init_datastore_example @@ -226,7 +226,7 @@ class ModelArgs: create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) # Create config diff --git a/tests/test_training.py b/tests/test_training.py index d16bd9ed3..f1357eb51 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -10,7 +10,7 @@ # First-party from neural_lam 
import config as nlconfig -from neural_lam.create_graph import create_graph_from_datastore +from neural_lam.create_graph_with_wmg import create_graph_from_datastore from neural_lam.datastore import DATASTORES from neural_lam.datastore.base import BaseRegularGridDatastore from neural_lam.models.ar_model import ARModel @@ -68,7 +68,7 @@ def run_simple_training(datastore, set_output_std, metrics_watch=None): create_graph_from_datastore( datastore=datastore, output_root_path=str(graph_dir_path), - n_max_levels=1, + archetype="keisler", ) data_module = WeatherDataModule(