336 EMA model weight averaging #338
Conversation
@abhaygoudannavar the file changes appear to address the pinned-memory worker issue and nothing related to EMA. Can you double-check?
Force-pushed from 66604ac to fa45696
Hey @sadamov! First off, thank you for catching that; you're totally right. When I tried to clean up the commit history by resetting the branch directly, GitHub automatically closed this PR. Since this one got pretty messy under the hood, I thought it would be best to just start fresh. I've opened a clean, new PR with the proper EMA implementation (callbacks, tests, CLI args, etc.) right here: #362. Thanks again for reviewing, and sorry for the extra noise on this one! I'll leave this PR closed and we can continue the review over on the new one.
Add an Exponential Moving Average (EMA) callback for model weights via a PyTorch Lightning Callback subclass. EMA maintains a shadow copy of parameters as a running average (θ_ema ← decay · θ_ema + (1 − decay) · θ_current), used for validation/test/inference while training continues on the raw optimizer-updated weights. This reduces per-step noise compounding during autoregressive rollouts and produces more stable checkpoints.
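For illustration, a minimal sketch of such a callback (the class name, hook choices, and default decay are assumptions here, not necessarily what this PR implements; the actual code lives in #362):

```python
# Sketch of an EMA weight-averaging callback; illustrative, not the PR's code.
import copy

import torch
import pytorch_lightning as pl


class EMACallback(pl.Callback):
    """Maintain an exponential moving average of model parameters.

    Shadow weights are updated after every training batch and swapped in
    for validation, so evaluation sees the smoothed weights while training
    continues on the raw optimizer-updated ones.
    """

    def __init__(self, decay: float = 0.999):
        self.decay = decay
        self.shadow = None    # EMA copy of the floating-point parameters
        self._backup = None   # raw weights stashed while shadow is swapped in

    def on_train_start(self, trainer, pl_module):
        # Initialise the shadow weights from the current model state.
        self.shadow = {
            k: v.detach().clone()
            for k, v in pl_module.state_dict().items()
            if v.dtype.is_floating_point
        }

    @torch.no_grad()
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # theta_ema <- decay * theta_ema + (1 - decay) * theta_current
        for k, v in pl_module.state_dict().items():
            if k in self.shadow:
                self.shadow[k].mul_(self.decay).add_(v.detach(), alpha=1 - self.decay)

    def on_validation_start(self, trainer, pl_module):
        # Swap the EMA weights in for evaluation (guard against the sanity
        # check that runs before training initialises the shadow copy).
        if self.shadow is not None:
            self._backup = copy.deepcopy(pl_module.state_dict())
            pl_module.load_state_dict({**self._backup, **self.shadow}, strict=False)

    def on_validation_end(self, trainer, pl_module):
        # Restore the raw weights so training resumes unaffected.
        # Test/predict hooks would mirror this swap-in/swap-out pattern.
        if self._backup is not None:
            pl_module.load_state_dict(self._backup)
            self._backup = None
```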
Changes:
Motivation:
EMA is standard across major neural weather prediction systems (GraphCast, for example, reports a 5–12% RMSE reduction over 10-day rollouts); neural-lam is currently the only major open-source neural weather framework without it.
Dependencies:
None beyond existing PyTorch + PyTorch Lightning.
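As a sketch of how the callback could be wired in with only those dependencies (the `--ema_decay` flag is a hypothetical name; the actual CLI args are part of #362):

```python
# Hypothetical wiring; flag name and default are assumptions, not the PR's API.
import argparse

import pytorch_lightning as pl

# EMACallback is the sketch from the description above.

parser = argparse.ArgumentParser()
parser.add_argument("--ema_decay", type=float, default=0.999,
                    help="Decay rate for the EMA shadow weights")
args = parser.parse_args()

trainer = pl.Trainer(callbacks=[EMACallback(decay=args.ema_decay)])
# trainer.fit(model, datamodule=...)  # validation now runs on EMA weights
```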
Issue Link #336
Type of change
Checklist before requesting a review
pull with --rebase option if possible).
Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee