feat: add EMA model weight averaging callback via `--ema_decay` #362
abhaygoudannavar wants to merge 6 commits into mllam:main from …
Conversation
@abhaygoudannavar thanks a lot for this suggestion to add EMA; I wasn't aware it can potentially improve forecast skill by ~5-10%. And thanks for your patience :)
I am not an expert on EMA, so I've added Joel to validate the design decisions.
The motivation is solid, and the implementation as a callback seems smart. The math, checkpoint persistence, and test coverage are all reasonable.
The main problem is that train_model.py has no --ema_decay CLI argument, so the callback cannot actually be enabled, right? Beyond that, there are some considerations about when the weights should be moved to which device, but more on that in later detailed reviews.
@sadamov You're absolutely right — the `--ema_decay` CLI argument and the wiring in train_model.py are missing from the current diff. It looks like those changes got lost during one of the merge-with-main cycles. Good catch! I'll push an update that:

- adds the `--ema_decay` CLI argument to `train_model.py` (default `None` = disabled) and appends the callback when set (a rough sketch of the intended wiring below)

Will update the PR shortly!
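For reference, the intended wiring could look roughly like this. This is a hypothetical sketch: only the `--ema_decay` argument name and `EMACallback` come from the PR; the import path and surrounding `train_model.py` structure are assumptions.

```python
import argparse

import pytorch_lightning as pl

from neural_lam.ema import EMACallback  # hypothetical module path

parser = argparse.ArgumentParser()
parser.add_argument(
    "--ema_decay",
    type=float,
    default=None,  # None = EMA disabled
    help="EMA decay for model weights, e.g. 0.999 (default: disabled)",
)
args = parser.parse_args()

callbacks = []
if args.ema_decay is not None:
    # Only attach the callback when the flag is given, so the
    # default behavior stays identical to current main
    callbacks.append(EMACallback(decay=args.ema_decay))

trainer = pl.Trainer(callbacks=callbacks)
```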
Describe your changes
This PR introduces an Exponential Moving Average (EMA) callback for model weights via a PyTorch Lightning Callback subclass.
EMA maintains a shadow copy of parameters as a running average:
`ema_weights = decay * ema_weights + (1 - decay) * current_weights`
These EMA weights are used during validation, testing, and inference, while training continues on the raw, optimizer-updated weights.
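In PyTorch this update is a single in-place lerp per parameter. A minimal sketch of the update rule above (not the PR's exact code):

```python
import torch

@torch.no_grad()
def ema_update(ema_params, model_params, decay):
    """In-place EMA update: ema = decay * ema + (1 - decay) * current."""
    for ema_p, p in zip(ema_params, model_params):
        # lerp_(p, w) computes ema_p + w * (p - ema_p),
        # which equals decay * ema_p + (1 - decay) * p for w = 1 - decay
        ema_p.lerp_(p, 1.0 - decay)
```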
Implementation details:
- `EMACallback` implementing lerp updates, weight swapping during evaluation, and checkpoint persistence (handling device transfers); see the sketch after this list.
- `--ema_decay` CLI argument (default `None` = disabled; e.g. `0.999`). When set, appends `EMACallback` to the trainer's callbacks.
- Tests covering `GraphLAM` and no-op behavior when disabled.
- CHANGELOG entry under `[unreleased] > Added`.
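For reviewers skimming the diff, a stripped-down skeleton of the shape described above. This is an illustrative sketch using standard `pytorch_lightning.Callback` hooks, not the actual code in this PR:

```python
import copy

import pytorch_lightning as pl
import torch


class EMACallback(pl.Callback):
    """Maintain an EMA shadow of model weights; evaluate with EMA weights."""

    def __init__(self, decay=0.999):
        self.decay = decay
        self.ema_state = None
        self._backup = None

    def on_fit_start(self, trainer, pl_module):
        # Initialize shadow weights unless already restored from a checkpoint
        if self.ema_state is None:
            self.ema_state = copy.deepcopy(pl_module.state_dict())

    @torch.no_grad()
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        for name, param in pl_module.state_dict().items():
            ema = self.ema_state[name]
            if ema.dtype.is_floating_point:
                # ema = decay * ema + (1 - decay) * param
                ema.lerp_(param.to(ema.device), 1.0 - self.decay)
            else:
                # Integer buffers (e.g. step counters) are copied, not averaged
                ema.copy_(param)

    def on_validation_start(self, trainer, pl_module):
        # Swap in EMA weights for evaluation, keep raw weights for training
        self._backup = copy.deepcopy(pl_module.state_dict())
        pl_module.load_state_dict(self.ema_state)

    def on_validation_end(self, trainer, pl_module):
        pl_module.load_state_dict(self._backup)
        self._backup = None

    def state_dict(self):
        return {"decay": self.decay, "ema_state": self.ema_state}

    def load_state_dict(self, state_dict):
        self.decay = state_dict["decay"]
        self.ema_state = state_dict["ema_state"]
```

Implementing `state_dict`/`load_state_dict` on the callback is what lets Lightning round-trip the shadow weights through checkpoints automatically.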
Motivation:
EMA is a standard technique in modern neural weather prediction (e.g., GraphCast reports a 5–12% RMSE reduction using EMA). It reduces per-step noise compounding during autoregressive rollouts and produces significantly more stable model checkpoints.
Dependencies:
`torch` and `pytorch_lightning`.
Issue Link
solves #336
Type of change
Checklist before requesting a review
- … (`pull` with `--rebase` option if possible).
Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
- … reflecting type of change (add section where missing):
Checklist for assignee