
336 ema model weight averaging#338

Closed
abhaygoudannavar wants to merge 0 commits into mllam:main from abhaygoudannavar:336-ema-model-weight-averaging

Conversation

@abhaygoudannavar
Contributor

Add an Exponential Moving Average (EMA) callback for model weights via a PyTorch Lightning
Callback subclass. EMA maintains a shadow copy of parameters as a running average (θ_ema ← decay * θ_ema + (1 - decay) * θ_current), used for validation/test/inference while training continues on raw optimizer-updated weights. This reduces per-step noise compounding during autoregressive rollouts and produces more stable checkpoints.
Changes:

  • neural_lam/callbacks.py (new): EMACallback(pl.Callback) with hooks for weight initialization, in-place lerp_ update, weight swapping during eval, and checkpoint persistence. Handles device transfer correctly when resuming from a checkpoint.
  • neural_lam/train_model.py: Added --ema_decay CLI argument (default None = disabled). When set (e.g. 0.999), appends EMACallback to the trainer's callback list.
  • tests/test_ema_callback.py (new): 7 tests covering decay validation, mathematical correctness of the running average, weight swap during validation/test, checkpoint save/load round-trip, and no-op behavior when EMA is disabled.
  • CHANGELOG.md: Added entry under added.
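For illustration, the in-place `lerp_` update and the eval-time weight swap described above could be sketched with plain PyTorch as below. The class name `EMAWeights` is hypothetical and the Lightning hook wiring is omitted; the PR's actual `EMACallback` in neural_lam/callbacks.py implements these steps as `pl.Callback` hooks.

```python
import torch


class EMAWeights:
    """Minimal sketch of EMA weight tracking (hypothetical, not the PR's class).

    update(): theta_ema <- decay * theta_ema + (1 - decay) * theta_current
    swap():   exchange live and shadow weights, e.g. around validation.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        if not 0.0 < decay < 1.0:
            raise ValueError("decay must be in (0, 1)")
        self.decay = decay
        self.model = model
        # Shadow copy starts as a clone of the current weights.
        self.shadow = [p.detach().clone() for p in model.parameters()]

    @torch.no_grad()
    def update(self):
        # lerp_(end, w) computes self + w * (end - self), which equals
        # decay * shadow + (1 - decay) * current when w = 1 - decay.
        for s, p in zip(self.shadow, self.model.parameters()):
            s.lerp_(p.detach(), 1.0 - self.decay)

    @torch.no_grad()
    def swap(self):
        # Swap live and shadow weights in place; call once before eval
        # (to run validation on EMA weights) and once after (to restore).
        for s, p in zip(self.shadow, self.model.parameters()):
            tmp = p.detach().clone()
            p.copy_(s)
            s.copy_(tmp)
```

In a Lightning callback, `update()` would run in `on_train_batch_end` and the two `swap()` calls would bracket the validation/test loops, so the optimizer always steps on the raw weights.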

Motivation:
EMA is standard across major neural weather prediction systems (GraphCast reports a 5–12% RMSE reduction over 10-day rollouts); neural-lam is currently the only major open-source neural weather framework without it.

Dependencies:
None beyond existing PyTorch + PyTorch Lightning.

Issue Link #336

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe their purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

@abhaygoudannavar
Contributor Author

@sadamov I have fixed issue #336 and wanted to know if there are any changes that still need to be made.

@sadamov
Collaborator

sadamov commented Mar 8, 2026

@abhaygoudannavar the file changes appear to address the pinned-memory worker issue, nothing related to EMA. Can you double-check?

@sadamov sadamov added enhancement New feature or request question Further information is requested labels Mar 8, 2026
@abhaygoudannavar abhaygoudannavar force-pushed the 336-ema-model-weight-averaging branch from 66604ac to fa45696 Compare March 9, 2026 19:37
@abhaygoudannavar
Contributor Author

Hey @sadamov! First off, thank you for catching that, you're totally right. It looks like I accidentally branched off from my earlier pin_memory work instead of a clean main branch, which is why none of the EMA code actually made it into the diff here.

When I tried to clean up the commit history by resetting the branch directly, GitHub automatically closed this PR. Since this one got pretty messy under the hood, I thought it would be best to just start fresh.

I've opened a clean, new PR with the proper EMA implementation (callbacks, tests, CLI args, etc.) right here: #362

Thanks again for reviewing, and sorry for the extra noise on this one! I'll leave this PR closed and we can continue the review over on the new one.

