
feat: add EMA model weight averaging callback via --ema_decay #362

Open

abhaygoudannavar wants to merge 6 commits into mllam:main from abhaygoudannavar:feature/add-ema-callback

Conversation

Contributor

@abhaygoudannavar commented Mar 9, 2026

Describe your changes

This PR introduces an Exponential Moving Average (EMA) callback for model weights via a PyTorch Lightning Callback subclass.
EMA maintains a shadow copy of parameters as a running average:
ema_weights = decay * ema_weights + (1 - decay) * current_weights.
These EMA weights are used during validation, testing, and inference, while training continues on the raw, optimizer-updated weights.
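In code, the update can be written with torch's in-place lerp_ (illustrative snippet, not taken from the diff):

```python
import torch

decay = 0.999

# Live (optimizer-updated) weights and their EMA shadow copy
current_weights = torch.randn(4, 4)
ema_weights = current_weights.detach().clone()

# Stand-in for one optimizer step changing the live weights
current_weights += 0.01 * torch.randn(4, 4)

# ema = decay * ema + (1 - decay) * current,
# expressed with torch's in-place lerp_
ema_weights.lerp_(current_weights, 1.0 - decay)
```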
Implementation details:

  • neural_lam/callbacks.py (new): EMACallback(pl.Callback) with hooks for weight initialization, in-place lerp_ updates, weight swapping during evaluation, and checkpoint persistence (handling device transfers); a simplified sketch follows below.
  • neural_lam/train_model.py: Added --ema_decay CLI argument (default None = disabled; e.g. 0.999). When set, appends EMACallback to the trainer's callbacks.
  • tests/test_ema_callback.py (new): 7 tests covering decay validation, update math, weight swapping, checkpoint round-trips, integration with GraphLAM, and no-op behavior when disabled.
  • CHANGELOG.md: Added entry under [unreleased] > Added.

Motivation:
EMA is a standard technique in modern neural weather prediction (e.g., GraphCast reports a 5–12% RMSE reduction from EMA). It reduces per-step noise compounding during autoregressive rollouts and produces significantly more stable model checkpoints.

Dependencies:
  • None beyond existing torch and pytorch_lightning.
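A simplified sketch of the callback structure (details such as attribute names and the exact set of hooks differ from the actual file; the full implementation also covers test/predict evaluation as described above):

```python
import pytorch_lightning as pl
import torch


class EMACallback(pl.Callback):
    """Maintain an EMA of model weights and swap it in for evaluation."""

    def __init__(self, decay: float = 0.999):
        if not 0.0 < decay < 1.0:
            raise ValueError(f"decay must be in (0, 1), got {decay}")
        self.decay = decay
        self.ema_params = None      # shadow copies of the model parameters
        self._backup_params = None  # raw weights stashed while EMA is swapped in

    def on_fit_start(self, trainer, pl_module):
        if self.ema_params is None:  # may already be populated from a checkpoint
            self.ema_params = [p.detach().clone() for p in pl_module.parameters()]

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # ema = decay * ema + (1 - decay) * current, done in place
        with torch.no_grad():
            for ema_p, p in zip(self.ema_params, pl_module.parameters()):
                ema_p.lerp_(p.detach(), 1.0 - self.decay)

    def on_validation_start(self, trainer, pl_module):
        # Evaluate with the EMA weights: stash the raw weights, copy EMA in
        self._backup_params = [p.detach().clone() for p in pl_module.parameters()]
        with torch.no_grad():
            for ema_p, p in zip(self.ema_params, pl_module.parameters()):
                p.copy_(ema_p)

    def on_validation_end(self, trainer, pl_module):
        # Restore the raw, optimizer-updated weights for continued training
        with torch.no_grad():
            for backup_p, p in zip(self._backup_params, pl_module.parameters()):
                p.copy_(backup_p)
        self._backup_params = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Persist EMA weights on CPU so checkpoints stay device-agnostic
        checkpoint["ema_params"] = [p.cpu() for p in self.ema_params]

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Move persisted EMA weights to the module's current device
        device = pl_module.device
        self.ema_params = [p.to(device) for p in checkpoint.get("ema_params", [])]
```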

Issue Link

Solves #336

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@abhaygoudannavar abhaygoudannavar mentioned this pull request Mar 9, 2026
@sadamov

This comment was marked as outdated.

@sadamov sadamov self-requested a review March 12, 2026 19:47
@sadamov sadamov added the enhancement New feature or request label Mar 12, 2026
@sadamov sadamov requested a review from joeloskarsson April 1, 2026 09:29
@sadamov sadamov self-assigned this Apr 1, 2026
Collaborator

@sadamov left a comment
@abhaygoudannavar thanks a lot for this suggestion to add EMA, I wasn't aware of it potentially improving forecast skill by ~5-10%. And thanks for the patience :)

I am not an expert on EMA and added Joel to validate the design decisions.

The motivation is solid, and the implementation as a callback seems smart. The math, checkpoint persistence, and test coverage are all reasonable.

The main problem is that train_model.py has no --ema_decay CLI argument, so the callback cannot actually be enabled, right? Beyond that there are some considerations about when to move to which device. But more about that in later detailed reviews.

Contributor Author

@abhaygoudannavar

@sadamov You're absolutely right — the --ema_decay CLI argument and the wiring in train_model.py are missing from the current diff. It looks like those changes got lost during one of the merge-with-main cycles. Good catch!

I'll push an update that:

  • Adds the --ema_decay CLI argument to train_model.py (default None = disabled)
  • Conditionally appends EMACallback(decay=args.ema_decay) to the trainer's callbacks list when --ema_decay is set

Regarding the device placement considerations you mentioned: I'm happy to discuss those in the detailed review. The current on_load_checkpoint does a .to(device) transfer, but I can see there might be edge cases (e.g., multi-GPU DDP) worth addressing. Looking forward to your and Joel's feedback on that.
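Roughly, the wiring I have in mind (a sketch, not the final code; the argument parser and trainer construction are simplified stand-ins for what train_model.py already has):

```python
import argparse

import pytorch_lightning as pl

from neural_lam.callbacks import EMACallback

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ema_decay",
    type=float,
    default=None,
    help="Decay for EMA of model weights, e.g. 0.999 (default: None = disabled)",
)
args = parser.parse_args()

# Only enable EMA when the flag is explicitly set
callbacks = []
if args.ema_decay is not None:
    callbacks.append(EMACallback(decay=args.ema_decay))

trainer = pl.Trainer(callbacks=callbacks)
```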

Will update the PR shortly!

