
Add create_graph_with_wmg.py CLI using weather-model-graphs #596

Open
prajwal-tech07 wants to merge 8 commits into mllam:main from
prajwal-tech07:issue-384/create-graph-with-wmg

Conversation

@prajwal-tech07

Describe your changes

Add a new CLI script create_graph_with_wmg.py that delegates graph creation
to weather-model-graphs (wmg)
and saves the output using wmg.save.to_neural_lam() in the tensor-on-disk
format expected by neural_lam.utils.load_graph().

This is the neural-lam side of the bridge described in #384. The wmg side
is mllam/weather-model-graphs#123,
which adds to_neural_lam() to wmg.save.

What this PR does:

  • Adds neural_lam/create_graph_with_wmg.py with support for all three wmg
    archetypes: keisler (flat single-scale), graphcast (flat multiscale),
    and hierarchical (Oskarsson)
  • Auto-computes mesh_node_distance from grid spacing when not specified
  • Reshapes datastore coordinates from (Nx, Ny, 2) → (N, 2) for wmg
  • Adds a deprecation warning to the old create_graph.py CLI pointing users
    to the new script
  • Adds weather-model-graphs[pytorch]>=0.3.0 as a dependency
  • Relaxes torch-geometric pin from ==2.3.1 to >=2.5.3 (required by
    wmg's pytorch extra)
  • Registers create_graph_with_wmg as a console script entry point

Dependencies: Requires mllam/weather-model-graphs#123
to be merged and released first (adds wmg.save.to_neural_lam()).
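
The coordinate reshape listed above can be sketched as follows. This is a minimal illustration with a made-up regular grid, assuming the two spatial axes are simply flattened in row-major order; the array names and grid values are hypothetical and the actual script may order the nodes differently.

```python
import numpy as np

# Hypothetical regular grid: Nx=4, Ny=3 (values are illustrative only).
nx, ny = 4, 3
x = np.linspace(0.0, 30.0, nx)
y = np.linspace(0.0, 20.0, ny)
xx, yy = np.meshgrid(x, y, indexing="ij")  # both have shape (Nx, Ny)
xy_grid = np.stack([xx, yy], axis=-1)      # shape (Nx, Ny, 2)

# Flatten the two spatial axes into a single node axis; the last axis
# still holds the (x, y) pair of each node.
xy_flat = xy_grid.reshape(-1, 2)           # shape (N, 2), N = Nx * Ny

print(xy_grid.shape, "->", xy_flat.shape)
```

With row-major flattening, node `i * Ny + j` of the flat array corresponds to grid cell `(i, j)`.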

Files changed (4 files, +305 −1):

| File | Change |
| --- | --- |
| neural_lam/create_graph_with_wmg.py | New CLI script (+188 lines) |
| neural_lam/create_graph.py | Added deprecation warning (+9 lines) |
| pyproject.toml | Added wmg dep, relaxed torch-geometric, added script entry (+5 −1) |
| tests/test_graph_creation.py | 12 new tests (9 wmg creation + 3 deprecation) (+103 lines) |
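
For readers unfamiliar with the pattern, a deprecation warning on an old entry point and a check that it fires might look roughly like the sketch below. The function name `old_cli_main`, the message text, and the choice of `DeprecationWarning` are assumptions for illustration; the PR's actual warning category and wording are not shown in this thread.

```python
import warnings


def old_cli_main():
    """Hypothetical stand-in for the old create_graph.py entry point."""
    warnings.warn(
        "create_graph is deprecated; use create_graph_with_wmg instead.",
        DeprecationWarning,  # assumed category
        stacklevel=2,
    )
    return "graph created the old way"


# Record warnings instead of printing them, so they can be inspected.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_cli_main()

deprecations = [w for w in caught if issubclass(w.category, DeprecationWarning)]
print(len(deprecations), deprecations[0].message)
```

In a pytest suite the same check would typically be written with `pytest.warns(DeprecationWarning)`.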

Issue Link

Solves #384 (neural-lam side)

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@leifdenby leifdenby added this to the v0.7.0 milestone Apr 13, 2026
prajwal-tech07 and others added 2 commits April 14, 2026 01:09
Update pyproject.toml to install weather-model-graphs from the
issue-384/to-neural-lam branch of the fork, so that CI and reviewers
can test the neural-lam side against the unreleased to_neural_lam()
changes before wmg PR mllam#123 is merged.

Will revert to a versioned PyPI dependency once PR mllam#123 is released.
@prajwal-tech07
Author

Merged Leif's pyproject.toml fix from #1. The pip, cpu CI failure is expected since tool.uv.sources isn't read by pip; only the uv jobs will correctly resolve weather-model-graphs from the fork branch. The 3 cancelled checks (pip gpu, uv cpu, uv gpu) were just cascading from the pip, cpu failure. All 21 tests pass locally.

Member

@leifdenby leifdenby left a comment

Looking good! A few suggestions for changes 🚀

Comment thread neural_lam/create_graph_with_wmg.py Outdated
}


def _estimate_mesh_node_distance(xy):
Member

maybe we could call this _estimate_grid_node_spacing? And then expose the grid-mesh-spacing ratio as a CLI arg that defaults to 3.0?

Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Renamed _estimate_mesh_node_distance → _estimate_grid_node_spacing — it now returns only the average grid spacing. The ×3 multiplier is replaced by a new grid_mesh_ratio parameter (default 3.0) exposed as --grid_mesh_ratio on the CLI.

Member

looks great, but it would be better to name it --grid_mesh_spacing_ratio I think - just to be clear about what it is the ratio between :) what do you think?

Author

--grid_mesh_ratio → --grid_mesh_spacing_ratio
Good call — renamed the CLI arg, function parameter, and docstrings throughout to make it clear it's the ratio between mesh-node and grid-node spacing.
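
The outcome of this thread can be sketched as below: an estimate of the average grid-node spacing, multiplied by `--grid_mesh_spacing_ratio` (default 3.0) when no mesh node distance is given. This is only an illustration under stated assumptions. The way the spacing is averaged, the `--mesh_node_distance` flag name, and the example grid are hypothetical, and the real `_estimate_grid_node_spacing` may be implemented differently.

```python
import argparse

import numpy as np


def estimate_grid_node_spacing(xy):
    """Mean distance between neighbouring nodes of a regular (Nx, Ny, 2) grid."""
    dx = np.linalg.norm(np.diff(xy, axis=0), axis=-1)  # neighbours along axis 0
    dy = np.linalg.norm(np.diff(xy, axis=1), axis=-1)  # neighbours along axis 1
    return float(np.mean(np.concatenate([dx.ravel(), dy.ravel()])))


parser = argparse.ArgumentParser()
parser.add_argument("--grid_mesh_spacing_ratio", type=float, default=3.0)
# Assumed flag name for the explicit override.
parser.add_argument("--mesh_node_distance", type=float, default=None)
args = parser.parse_args([])  # empty argv: fall back to the defaults

# Hypothetical 5 x 4 grid with 2.5 units between neighbouring nodes.
x, y = np.meshgrid(np.arange(5) * 2.5, np.arange(4) * 2.5, indexing="ij")
xy = np.stack([x, y], axis=-1)

grid_spacing = estimate_grid_node_spacing(xy)
mesh_node_distance = args.mesh_node_distance
if mesh_node_distance is None:
    # Mesh nodes are placed `ratio` times further apart than grid nodes.
    mesh_node_distance = args.grid_mesh_spacing_ratio * grid_spacing

print(grid_spacing, mesh_node_distance)
```

Keeping the spacing estimate separate from the ratio means the function answers one question (how far apart are grid nodes), while the user controls how coarse the mesh is relative to the grid.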

Comment thread neural_lam/create_graph_with_wmg.py Outdated
Comment thread neural_lam/create_graph_with_wmg.py

@pytest.mark.parametrize("archetype", ["keisler", "graphcast", "hierarchical"])
@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_wmg_graph_creation(datastore_name, archetype):
Member

This is great, but we could actually use the testing that #323 introduces. I should try and get that finished so that we can merge both in together :)

Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Thanks, the tests are working well as-is! Happy to adapt them to the testing infrastructure from #323 once that's ready — makes total sense to merge them together. Let me know if there's anything I can do to help with #323!

Member

ok, how about we say that, depending on whether I finish #323 soon, we either a) merge this in with the tests as you have them implemented already or b) merge #323 in first and then adapt the testing here to use the testing being introduced in #323?

Author

Sounds like a great plan, either way works for me. Happy to adapt the tests to #323's infrastructure once it lands, or merge as-is if this PR is ready.

Comment thread pyproject.toml Outdated
  "plotly>=5.15.0",
  "torch>=2.3.0",
- "torch-geometric==2.3.1",
+ "torch-geometric>=2.5.3",
Member

is this a necessary change?

Author

@prajwal-tech07 prajwal-tech07 Apr 14, 2026

Yes, this is necessary. The weather-model-graphs[pytorch] extra requires torch-geometric>=2.5.3 in its own dependencies, so we need to allow at least that version here too.

Member

ah, in that case I think we should maybe reduce the required version number on the weather-model-graphs side, since all we are doing with pytorch-geometric is using it to convert the networkx.DiGraph objects to torch.Tensor objects, and we already do that in the current create_graph.py code in neural-lam with the older version. So how about we change weather-model-graphs to require torch-geometric==2.3.1 so the two are in sync? Then we don't have to change anything on the neural-lam side.

Author

You're right — wmg only uses from_networkx which works fine with 2.3.1. I've also pushed a commit to the wmg PR branch (prajwal-tech07/weather-model-graphs@d0c693d) pinning it to ==2.3.1 there too, so both repos stay in sync.

Comment thread pyproject.toml
Member

I will make a separate PR for this. I think this is some general CI maintenance that should be merged in before this PR.

Author

Noted — happy to revert the CI changes from this branch if you'd prefer to handle them in a separate PR.

…sh_ratio, use stacked=True, add return_components comment, add README entry
Member

@leifdenby leifdenby left a comment

Comment thread README.md
Comment thread pyproject.toml
@sadamov sadamov linked an issue Apr 15, 2026 that may be closed by this pull request
Member

@leifdenby leifdenby left a comment

Ok, the CLI interface looks great now, as do the readme instructions and install works great (just tried locally on my laptop too).

I think we should also change the tests that use graph creation (i.e. the training examples). They all currently use the create_graph_from_datastore() function in https://github.com/mllam/neural-lam/blob/main/neural_lam/create_graph.py#L540, but I think they should use the new one that you have created, which creates the graph with wmg instead.

Migrated all test files that used the old create_graph_from_datastore()
from neural_lam.create_graph to use the new wmg-based version from
neural_lam.create_graph_with_wmg instead.

Changes:
- test_datasets.py: Use wmg create_graph_from_datastore with archetype='keisler'
- test_clamping.py: Use wmg create_graph_from_datastore with archetype='keisler'
- test_plotting.py: Use wmg create_graph_from_datastore with archetype='keisler'
- test_training.py: Use wmg create_graph_from_datastore with archetype='keisler'
- test_plot_graph.py: Use wmg create_graph_from_datastore with keisler and
  hierarchical archetypes. Removed multiscale (graphcast) parametrization
  since the graphcast archetype produces multi-level m2m edges without
  up/down edges, which is not yet compatible with utils.load_graph().
  Graphcast graph creation is separately tested in test_graph_creation.py.
@prajwal-tech07
Author

Thanks for the suggestion, @leifdenby! Done in 76c7ed7.

All test files that used the old create_graph_from_datastore() from neural_lam.create_graph have been migrated to use the new wmg-based version from neural_lam.create_graph_with_wmg:

  • test_datasets.py, test_clamping.py, test_plotting.py, test_training.py: Switched to create_graph_from_datastore(archetype="keisler") (equivalent to the previous n_max_levels=1 flat graph).
  • test_plot_graph.py: Switched to wmg-based graph creation with keisler and hierarchical archetypes. Removed the multiscale (graphcast) parametrization from this test since the graphcast archetype produces multi-level m2m edges without up/down edges, which isn't compatible with utils.load_graph() (it infers hierarchical=True when n_levels > 1 and then tries to load mesh_up_edge_index.pt). Graphcast graph creation is already separately tested in test_graph_creation.py where all 3 archetypes pass.
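
The incompatibility described in the last bullet can be illustrated with a simplified check. This is not the neural_lam implementation; it only mimics the described behaviour (more than one mesh level is taken to mean a hierarchical graph, which then needs up/down edge files). The helper name `check_loadable` is hypothetical, and the file names other than mesh_up_edge_index.pt are assumptions.

```python
import tempfile
from pathlib import Path


def check_loadable(graph_dir, n_levels):
    """Return (hierarchical, missing_files) for a saved graph directory.

    Simplified illustration: more than one level of m2m edges is interpreted
    as a hierarchical graph, which additionally needs up/down edge files.
    """
    hierarchical = n_levels > 1
    required = ["m2m_edge_index.pt"]  # assumed file name
    if hierarchical:
        required += ["mesh_up_edge_index.pt", "mesh_down_edge_index.pt"]
    missing = [n for n in required if not (Path(graph_dir) / n).exists()]
    return hierarchical, missing


with tempfile.TemporaryDirectory() as tmp:
    # A flat multiscale graph on disk: m2m edges only, no up/down edges.
    (Path(tmp) / "m2m_edge_index.pt").touch()
    flat = check_loadable(tmp, n_levels=1)
    multiscale = check_loadable(tmp, n_levels=3)

print(flat)
print(multiscale)
```

A single-level graph passes the check, while a multi-level graph without up/down edges is flagged as hierarchical and reported as missing files, which mirrors why the graphcast parametrization was removed from test_plot_graph.py.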

All tests pass locally (55/55 in the modified files, 72/72 in the full suite).

@leifdenby
Member

leifdenby commented Apr 15, 2026

I just completed a 200 epoch training (on a DGX Spark machine) on the DANRA test dataset using a keisler graph created with your new create_graph_with_wmg script and the training loss decreases as expected 🥳

[Image: W&B chart, 15/4/2026, 15:39:26]

I ran the following commands:

uv run python -m neural_lam.create_graph_with_wmg --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml --archetype keisler
uv run python -m neural_lam.train_model --config_path tests/datastore_examples/mdp/danra_100m_winds/config.yaml

So I will give this a final review, but I think this is nearly ready to merge :)

Development

Successfully merging this pull request may close these issues.

Supporting tensor-on-disk-format from weather-model-graphs
