Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a4b8797
fix #375
kshirajahere Mar 10, 2026
fcc5b3c
Rename plot_error_map to plot_error_heatmap
kshirajahere Mar 10, 2026
72ebc12
Implement HeatmapDatastore for plotting tests
kshirajahere Mar 10, 2026
7052efc
Update CHANGELOG with recent fixes
kshirajahere Mar 10, 2026
4069d62
Apply review fixes: adaptive layout, NaN guard, deprecated wrapper, d…
kshirajahere Mar 10, 2026
cc0231d
Improve error heatmap relative scaling
kshirajahere Mar 16, 2026
962ea1d
Merge origin/main into PR 376 plotting branch
kshirajahere Mar 20, 2026
d05d97d
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
kshirajahere Mar 23, 2026
a006b9e
fix: address PR 376 heatmap review feedback
kshirajahere Mar 25, 2026
3e6cccf
Address final plotting review comments
kshirajahere Mar 28, 2026
63c4186
Fix flake8 line length in plotting tests
kshirajahere Mar 28, 2026
45c965c
Address remaining plotting review nits
kshirajahere Mar 28, 2026
2fa7bb8
ci: retrigger cancelled GPU checks
kshirajahere Mar 30, 2026
7e6108e
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
sadamov Apr 1, 2026
bf7580e
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
kshirajahere Apr 3, 2026
fd7a6aa
docs: clarify heatmap scaling and fallback behavior
kshirajahere Apr 14, 2026
df26bbe
Merge origin/main into fix/-#375-Improvements-to-plot_error_map-function
kshirajahere Apr 14, 2026
ec76907
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
sadamov Apr 16, 2026
a65ece7
refactor: replace implicit heatmap normalization chain with explicit …
sadamov Apr 17, 2026
85d68b6
precommits
sadamov Apr 17, 2026
1e6f50f
undo some unnecessary changes
sadamov Apr 17, 2026
11a925d
Merge pull request #1 from sadamov/fix/explicit-normalization-parameter
kshirajahere Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Replace `shell=True` subprocess call in `compute_standardization_stats.py` with a safe argument list and Python-side hostname parsing to prevent command injection via `SLURM_JOB_NODELIST` [\#264](https://github.com/mllam/neural-lam/pull/264) @ashum9

- Avoid NaN when standardizing fields with zero std [#189](https://github.com/mllam/neural-lam/pull/189) @varunsiravuri
- Improve metric heatmaps by renaming `plot_error_map`, using a relative
cross-variable color scale while keeping original values in annotations, and
scaling figure size, tick labels, and annotation text with the number of
variables and lead times so larger evaluation outputs remain readable
([#375](https://github.com/mllam/neural-lam/issues/375))
Comment thread
kshirajahere marked this conversation as resolved.
Outdated
- Replaces multiple `assert` statements used for runtime input validation with explicit `ValueError` [\#279](https://github.com/mllam/neural-lam/pull/279) @Sir-Sloth-The-Lazy

- Fix README image paths to use absolute GitHub URLs so images display correctly on PyPI [\#188](https://github.com/mllam/neural-lam/pull/188) @bk-simon
Expand Down
16 changes: 9 additions & 7 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def on_validation_epoch_end(self):
"""
Compute val metrics at the end of val epoch
"""
# Create error maps for all test metrics
# Create error heatmaps for all validation metrics
self.aggregate_and_plot_metrics(self.val_metrics, prefix="val")

# Clear lists with validation metrics values
Expand Down Expand Up @@ -421,9 +421,10 @@ def test_step(self, batch, batch_idx):
batch_size=batch[0].shape[0],
)

# Compute all evaluation metrics for error maps Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
# Compute all evaluation metrics for error heatmaps. Note:
# explicitly list metrics here, as test_metrics can contain
# additional ones, computed differently, but that should be
# aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
Expand Down Expand Up @@ -610,7 +611,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
Return: log_dict: dict with everything to log for given metric
"""
log_dict = {}
metric_fig = vis.plot_error_map(
metric_fig = vis.plot_error_heatmap(
errors=metric_tensor,
datastore=self._datastore,
)
Expand Down Expand Up @@ -642,7 +643,8 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):

def aggregate_and_plot_metrics(self, metrics_dict, prefix):
"""
Aggregate and create error map plots for all metrics in metrics_dict
Aggregate and create error heatmap plots for all metrics in
metrics_dict

metrics_dict: dictionary with metric_names and list of tensors
with step-evals.
Expand Down Expand Up @@ -700,7 +702,7 @@ def on_test_epoch_end(self):
Compute test metrics and make plots at the end of test epoch. Will
gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
# Create error heatmaps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")

# Plot spatial loss maps
Expand Down
Loading
Loading