-
Notifications
You must be signed in to change notification settings - Fork 263
Add create_graph_with_wmg.py CLI using weather-model-graphs #596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
f51beba
dc77271
2500ade
d6da819
2c679e7
8a690e4
76c7ed7
8fe2e3e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| # Standard library | ||
| import os | ||
| from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser | ||
|
|
||
| # Third-party | ||
| import numpy as np | ||
| import weather_model_graphs as wmg | ||
|
|
||
| # Local | ||
| from .config import load_config_and_datastore | ||
| from .datastore.base import BaseRegularGridDatastore | ||
|
|
||
| ARCHETYPE_FUNCTIONS = { | ||
| "keisler": wmg.create.archetype.create_keisler_graph, | ||
| "graphcast": wmg.create.archetype.create_graphcast_graph, | ||
| "hierarchical": wmg.create.archetype.create_oskarsson_hierarchical_graph, | ||
| } | ||
|
|
||
|
|
||
| def _estimate_mesh_node_distance(xy): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe we could call this
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed _estimate_mesh_node_distance →
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks great, but it would better to name it
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| """Estimate a reasonable mesh node distance from grid coordinates. | ||
|
|
||
| Uses the average grid spacing to produce a mesh that is roughly 3x | ||
| coarser than the grid, similar to the default behaviour of the old | ||
| ``create_graph.py`` script. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| xy : np.ndarray | ||
| Grid coordinates of shape ``(N, 2)``. | ||
|
|
||
| Returns | ||
| ------- | ||
| float | ||
| Estimated mesh node distance in coordinate units. | ||
| """ | ||
| x_range = np.ptp(xy[:, 0]) | ||
| y_range = np.ptp(xy[:, 1]) | ||
| n_points = len(xy) | ||
| # avg grid spacing ≈ sqrt(area / n_points) | ||
| avg_spacing = np.sqrt(x_range * y_range / n_points) | ||
| # mesh is ~3x coarser than the grid | ||
| return float(avg_spacing * 3) | ||
|
|
||
|
|
||
| def create_graph_from_datastore( | ||
| datastore, | ||
| output_root_path, | ||
| archetype="keisler", | ||
| mesh_node_distance=None, | ||
| level_refinement_factor=3, | ||
| max_num_levels=None, | ||
| ): | ||
| """Create graph using weather-model-graphs and save in neural-lam format. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| datastore : BaseRegularGridDatastore | ||
| Datastore providing grid coordinates. | ||
| output_root_path : str | ||
| Directory where the .pt graph files will be saved. | ||
| archetype : str | ||
| Graph archetype to create: ``"keisler"``, ``"graphcast"``, or | ||
| ``"hierarchical"``. | ||
| mesh_node_distance : float or None | ||
| Distance between created mesh nodes (in coordinate units). If None, | ||
| automatically estimated from the grid spacing. | ||
| level_refinement_factor : int | ||
| Refinement factor between mesh hierarchy levels. Only used for | ||
| ``"graphcast"`` and ``"hierarchical"`` archetypes. | ||
| max_num_levels : int or None | ||
| Maximum number of mesh hierarchy levels. Only used for ``"graphcast"`` | ||
| and ``"hierarchical"`` archetypes. | ||
| """ | ||
| if not isinstance(datastore, BaseRegularGridDatastore): | ||
| raise NotImplementedError( | ||
| "Only graph creation for BaseRegularGridDatastore is supported" | ||
| ) | ||
|
|
||
| if archetype not in ARCHETYPE_FUNCTIONS: | ||
| raise ValueError( | ||
| f"Unknown archetype '{archetype}'. " | ||
| f"Must be one of: {list(ARCHETYPE_FUNCTIONS.keys())}" | ||
| ) | ||
|
|
||
| xy = datastore.get_xy(category="state", stacked=False) | ||
|
|
||
|
leifdenby marked this conversation as resolved.
Outdated
|
||
| # wmg expects coords as 2D array of shape (num_nodes, 2), but the | ||
| # datastore may return a 3D array of shape (Nx, Ny, 2) when | ||
| # stacked=False. Reshape to (N, 2) for wmg. | ||
| xy = np.array(xy) | ||
| if xy.ndim == 3: | ||
| xy = xy.reshape(-1, 2) | ||
|
|
||
| if mesh_node_distance is None: | ||
| mesh_node_distance = _estimate_mesh_node_distance(xy) | ||
|
|
||
| # Build keyword arguments for the archetype function | ||
| archetype_kwargs = dict( | ||
| coords=xy, | ||
| mesh_node_distance=mesh_node_distance, | ||
| return_components=True, | ||
|
leifdenby marked this conversation as resolved.
|
||
| ) | ||
|
|
||
| # Only multiscale/hierarchical archetypes accept these parameters | ||
| if archetype in ("graphcast", "hierarchical"): | ||
| archetype_kwargs["level_refinement_factor"] = level_refinement_factor | ||
| archetype_kwargs["max_num_levels"] = max_num_levels | ||
|
|
||
| archetype_fn = ARCHETYPE_FUNCTIONS[archetype] | ||
| graph_components = archetype_fn(**archetype_kwargs) | ||
|
|
||
| hierarchical = archetype == "hierarchical" | ||
|
|
||
| wmg.save.to_neural_lam( | ||
| graph_components=graph_components, | ||
| output_directory=output_root_path, | ||
| hierarchical=hierarchical, | ||
| ) | ||
|
|
||
|
|
||
| def cli(input_args=None): | ||
| """Command-line interface for graph creation using weather-model-graphs.""" | ||
| parser = ArgumentParser( | ||
| description="Graph generation for neural-lam using " | ||
| "weather-model-graphs (wmg)", | ||
| formatter_class=ArgumentDefaultsHelpFormatter, | ||
| ) | ||
| parser.add_argument( | ||
| "--config_path", | ||
| type=str, | ||
| help="Path to neural-lam configuration file", | ||
| ) | ||
| parser.add_argument( | ||
| "--name", | ||
| type=str, | ||
| default="multiscale", | ||
| help="Name to save graph as (used as subdirectory name)", | ||
| ) | ||
| parser.add_argument( | ||
| "--archetype", | ||
| type=str, | ||
| default="keisler", | ||
| choices=["keisler", "graphcast", "hierarchical"], | ||
| help="Graph archetype to create", | ||
| ) | ||
| parser.add_argument( | ||
| "--mesh_node_distance", | ||
| type=float, | ||
| default=None, | ||
| help="Distance between mesh nodes (in coordinate units). " | ||
| "If not set, estimated automatically from grid spacing.", | ||
| ) | ||
| parser.add_argument( | ||
| "--level_refinement_factor", | ||
| type=int, | ||
| default=3, | ||
| help="Refinement factor between mesh hierarchy levels " | ||
| "(only used for graphcast and hierarchical)", | ||
| ) | ||
| parser.add_argument( | ||
| "--max_num_levels", | ||
| type=int, | ||
| default=None, | ||
| help="Maximum number of mesh levels " | ||
| "(only used for graphcast and hierarchical)", | ||
| ) | ||
| args = parser.parse_args(input_args) | ||
|
|
||
| assert ( | ||
| args.config_path is not None | ||
| ), "Specify your config with --config_path" | ||
|
|
||
| # Load neural-lam configuration and datastore to use | ||
| _, datastore = load_config_and_datastore(config_path=args.config_path) | ||
|
|
||
| create_graph_from_datastore( | ||
| datastore=datastore, | ||
| output_root_path=os.path.join(datastore.root_path, "graph", args.name), | ||
| archetype=args.archetype, | ||
| mesh_node_distance=args.mesh_node_distance, | ||
| level_refinement_factor=args.level_refinement_factor, | ||
| max_num_levels=args.max_num_levels, | ||
| ) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| cli() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,17 +30,21 @@ dependencies = [ | |
| "matplotlib>=3.7.0", | ||
| "plotly>=5.15.0", | ||
| "torch>=2.3.0", | ||
| "torch-geometric==2.3.1", | ||
| "torch-geometric>=2.5.3", | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this a necessary change?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is necessary. The weather-model-graphs[pytorch] extra requires
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah, in that case I think we maybe should reduce required version number on the
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You're right — wmg only uses |
||
| "parse>=1.20.2", | ||
| "dataclass-wizard<0.31.0", | ||
| "mllam-data-prep>=0.5.0", | ||
| "mlflow>=2.16.2", | ||
| "boto3>=1.35.32", | ||
| "nvidia-ml-py>=13.580.82", | ||
| "pillow>=9.0.0", | ||
| "weather-model-graphs", | ||
| ] | ||
| requires-python = ">=3.10" | ||
|
|
||
| [project.scripts] | ||
| create_graph_with_wmg = "neural_lam.create_graph_with_wmg:cli" | ||
|
leifdenby marked this conversation as resolved.
|
||
|
|
||
| [dependency-groups] | ||
| dev = ["pre-commit>=3.8.0", "pytest>=8.3.2", "pooch>=1.8.2"] | ||
|
|
||
|
|
@@ -130,3 +134,6 @@ exclude = [ | |
| ".venv/", | ||
| "venv/", | ||
| ] | ||
|
|
||
| [tool.uv.sources] | ||
| weather-model-graphs = { git = "https://github.com/prajwal-tech07/weather-model-graphs", rev = "issue-384/to-neural-lam" } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| # Standard library | ||
| import tempfile | ||
| import warnings | ||
| from pathlib import Path | ||
|
|
||
| # Third-party | ||
|
|
@@ -8,6 +9,9 @@ | |
|
|
||
| # First-party | ||
| from neural_lam.create_graph import create_graph_from_datastore | ||
| from neural_lam.create_graph_with_wmg import ( | ||
| create_graph_from_datastore as wmg_create_graph_from_datastore, | ||
| ) | ||
| from neural_lam.datastore import DATASTORES | ||
| from neural_lam.datastore.base import BaseRegularGridDatastore | ||
| from tests.conftest import init_datastore_example | ||
|
|
@@ -117,3 +121,102 @@ def test_graph_creation(datastore_name, graph_name): | |
| assert r.shape[0] == 2 # adjacency matrix uses two rows | ||
| elif file_id.endswith("_features"): | ||
| assert r.shape[1] == d_features | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"]) | ||
| @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) | ||
| def test_wmg_graph_creation(datastore_name, archetype): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is great, but we could actually use the testing that #323 introduces. I should try and get that finished so that we can merge both in together :)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sounds like a great plan — either way works for me. Happy to adapt the tests to #323's infrastructure once it lands, or merge as-is if this PR is ready |
||
| """Check that graph creation via weather-model-graphs produces the | ||
| expected .pt files with correct shapes and types.""" | ||
| datastore = init_datastore_example(datastore_name) | ||
|
|
||
| if not isinstance(datastore, BaseRegularGridDatastore): | ||
| pytest.skip( | ||
| f"Skipping test for {datastore_name} as it is not a regular " | ||
| "grid datastore." | ||
| ) | ||
|
|
||
| hierarchical = archetype == "hierarchical" | ||
|
|
||
| required_graph_files = [ | ||
| "m2m_edge_index.pt", | ||
| "g2m_edge_index.pt", | ||
| "m2g_edge_index.pt", | ||
| "m2m_features.pt", | ||
| "g2m_features.pt", | ||
| "m2g_features.pt", | ||
| "mesh_features.pt", | ||
| ] | ||
| if hierarchical: | ||
| required_graph_files.extend( | ||
| [ | ||
| "mesh_up_edge_index.pt", | ||
| "mesh_down_edge_index.pt", | ||
| "mesh_up_features.pt", | ||
| "mesh_down_features.pt", | ||
| ] | ||
| ) | ||
|
|
||
| d_features = 3 | ||
| d_mesh_static = 2 | ||
|
|
||
| with tempfile.TemporaryDirectory() as tmpdir: | ||
| graph_dir_path = Path(tmpdir) / "graph" / archetype | ||
|
|
||
| wmg_create_graph_from_datastore( | ||
| datastore=datastore, | ||
| output_root_path=str(graph_dir_path), | ||
| archetype=archetype, | ||
| ) | ||
|
|
||
| assert graph_dir_path.exists() | ||
|
|
||
| # check that all the required files are present | ||
| for file_name in required_graph_files: | ||
| assert (graph_dir_path / file_name).exists() | ||
|
|
||
| # try to load each and ensure they have the right shape | ||
| for file_name in required_graph_files: | ||
| file_id = Path(file_name).stem | ||
| result = torch.load(graph_dir_path / file_name, weights_only=True) | ||
|
|
||
| if file_id.startswith("g2m") or file_id.startswith("m2g"): | ||
| assert isinstance(result, torch.Tensor) | ||
|
|
||
| if file_id.endswith("_index"): | ||
| assert result.shape[0] == 2 | ||
| elif file_id.endswith("_features"): | ||
| assert result.shape[1] == d_features | ||
|
|
||
| elif file_id.startswith("m2m") or file_id.startswith("mesh"): | ||
| assert isinstance(result, list) | ||
|
|
||
| for r in result: | ||
| assert isinstance(r, torch.Tensor) | ||
|
|
||
| if file_id == "mesh_features": | ||
| assert r.shape[1] == d_mesh_static | ||
| elif file_id.endswith("_index"): | ||
| assert r.shape[0] == 2 | ||
| elif file_id.endswith("_features"): | ||
| assert r.shape[1] == d_features | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) | ||
| def test_old_create_graph_deprecation_warning(datastore_name): | ||
| """Check that the old create_graph CLI emits a deprecation warning.""" | ||
| # First-party | ||
| from neural_lam.create_graph import cli as old_cli | ||
|
|
||
| with warnings.catch_warnings(record=True) as w: | ||
| warnings.simplefilter("always") | ||
| try: | ||
| old_cli(["--config_path", "nonexistent.yaml"]) | ||
| except Exception: | ||
| pass # We only care about the warning, not the error | ||
|
|
||
| deprecation_warnings = [ | ||
| x for x in w if issubclass(x.category, DeprecationWarning) | ||
| ] | ||
| assert len(deprecation_warnings) >= 1 | ||
| assert "create_graph_with_wmg" in str(deprecation_warnings[0].message) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will make a separate PR for this. I think this is some general CI maintenance that should be merged in before this PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Noted — happy to revert the CI changes from this branch if you'd prefer to handle them in a separate PR.