Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
a4b8797
fix #375
kshirajahere Mar 10, 2026
fcc5b3c
Rename plot_error_map to plot_error_heatmap
kshirajahere Mar 10, 2026
72ebc12
Implement HeatmapDatastore for plotting tests
kshirajahere Mar 10, 2026
7052efc
Update CHANGELOG with recent fixes
kshirajahere Mar 10, 2026
4069d62
Apply review fixes: adaptive layout, NaN guard, deprecated wrapper, d…
kshirajahere Mar 10, 2026
cc0231d
Improve error heatmap relative scaling
kshirajahere Mar 16, 2026
962ea1d
Merge origin/main into PR 376 plotting branch
kshirajahere Mar 20, 2026
d05d97d
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
kshirajahere Mar 23, 2026
a006b9e
fix: address PR 376 heatmap review feedback
kshirajahere Mar 25, 2026
3e6cccf
Address final plotting review comments
kshirajahere Mar 28, 2026
63c4186
Fix flake8 line length in plotting tests
kshirajahere Mar 28, 2026
45c965c
Address remaining plotting review nits
kshirajahere Mar 28, 2026
2fa7bb8
ci: retrigger cancelled GPU checks
kshirajahere Mar 30, 2026
7e6108e
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
sadamov Apr 1, 2026
bf7580e
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
kshirajahere Apr 3, 2026
fd7a6aa
docs: clarify heatmap scaling and fallback behavior
kshirajahere Apr 14, 2026
df26bbe
Merge origin/main into fix/-#375-Improvements-to-plot_error_map-function
kshirajahere Apr 14, 2026
ec76907
Merge branch 'main' into fix/-#375-Improvements-to-plot_error_map-fun…
sadamov Apr 16, 2026
a65ece7
refactor: replace implicit heatmap normalization chain with explicit …
sadamov Apr 17, 2026
85d68b6
precommits
sadamov Apr 17, 2026
1e6f50f
undo some unnecessary changes
sadamov Apr 17, 2026
11a925d
Merge pull request #1 from sadamov/fix/explicit-normalization-parameter
kshirajahere Apr 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Change the default ensemble-loading behavior in `WeatherDataset` / `WeatherDataModule` to use all ensemble members as independent samples for ensemble datastores (with matching ensemble-member selection for forcing when available); single-member behavior now requires explicitly opting in via `--load_single_member` [\#332](https://github.com/mllam/neural-lam/pull/332) @kshirajahere
- Refactor graph loading: move zero-indexing out of the model and update plotting to prepare using the research-branch graph I/O [\#184](https://github.com/mllam/neural-lam/pull/184) @zweihuehner
- Replace `print()`-based `rank_zero_print` with `loguru` `logger.info()` for structured log-level control ([#33](https://github.com/mllam/neural-lam/issues/33))
- Change metric heatmap (`plot_error_map`, now `plot_error_heatmap`) to use a
shared cross-variable color scale instead of per-row normalization, add a
colorbar, and scale figure size and font sizes with grid dimensions
([#375](https://github.com/mllam/neural-lam/issues/375))

### Fixed

Expand Down
16 changes: 9 additions & 7 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def on_validation_epoch_end(self):
"""
Compute val metrics at the end of val epoch
"""
# Create error maps for all test metrics
# Create error heatmaps for all validation metrics
self.aggregate_and_plot_metrics(self.val_metrics, prefix="val")

if self.trainer.is_global_zero:
Expand Down Expand Up @@ -436,9 +436,10 @@ def test_step(self, batch, batch_idx):
batch_size=batch[0].shape[0],
)

# Compute all evaluation metrics for error maps Note: explicitly list
# metrics here, as test_metrics can contain additional ones, computed
# differently, but that should be aggregated on_test_epoch_end
# Compute all evaluation metrics for error heatmaps. Note:
# explicitly list metrics here, as test_metrics can contain
# additional ones, computed differently, but that should be
# aggregated on_test_epoch_end
for metric_name in ("mse", "mae"):
metric_func = metrics.get_metric(metric_name)
batch_metric_vals = metric_func(
Expand Down Expand Up @@ -669,7 +670,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
Return: log_dict: dict with everything to log for given metric
"""
log_dict = {}
metric_fig = vis.plot_error_map(
metric_fig = vis.plot_error_heatmap(
errors=metric_tensor,
datastore=self._datastore,
)
Expand Down Expand Up @@ -702,7 +703,8 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):

def aggregate_and_plot_metrics(self, metrics_dict, prefix):
"""
Aggregate and create error map plots for all metrics in metrics_dict
Aggregate and create error heatmap plots for all metrics in
metrics_dict

metrics_dict: dictionary with metric_names and list of tensors
with step-evals.
Expand Down Expand Up @@ -760,7 +762,7 @@ def on_test_epoch_end(self):
Compute test metrics and make plots at the end of test epoch. Will
gather stored tensors and perform plotting and logging on rank 0.
"""
# Create error maps for all test metrics
# Create error heatmaps for all test metrics
self.aggregate_and_plot_metrics(self.test_metrics, prefix="test")

# Plot spatial loss maps
Expand Down
267 changes: 230 additions & 37 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# Standard library
import warnings

# Third-party
import cartopy.crs as ccrs
import cartopy.feature as cfeature
Expand All @@ -12,11 +15,20 @@
from . import utils
from .datastore.base import BaseRegularGridDatastore

# Font sizes shared across all plot functions for visual consistency.
# Font sizes shared across projection-aware plot functions.
_TITLE_SIZE = 13 # suptitle and per-axes titles
_LABEL_SIZE = 11 # axis / colorbar labels
_TICK_SIZE = 11 # tick labels

# Annotations become unreadable when cells are smaller than this (in points)
# or when the total number of cells exceeds a readable count.
_MIN_CELL_SIZE_FOR_ANNOTATIONS = 18
_MAX_CELLS_FOR_ANNOTATIONS = 800
_HEATMAP_CMAP = matplotlib.colors.LinearSegmentedColormap.from_list(
"error_heatmap_white_red",
["#ffffff", "#fee5d9", "#fcae91", "#fb6a4a", "#cb181d"],
)


def _tex_safe(s: str) -> str:
"""Escape TeX special characters in s if TeX rendering is currently active.
Expand All @@ -29,6 +41,132 @@ def _tex_safe(s: str) -> str:
return s


def _compute_heatmap_layout(n_rows: int, n_cols: int) -> dict[str, float]:
    """Derive figure dimensions and font sizes from the heatmap shape."""
    largest_side = max(n_rows, n_cols)

    # Grow the canvas with the grid (~0.8 x 0.5 inches per cell) but never
    # shrink below an 8 x 4.5 inch minimum.
    width_in = float(max(4.5 + 0.8 * n_cols, 8.0))
    height_in = float(max(2.5 + 0.5 * n_rows, 4.5))

    # Estimate per-cell extent in points (1 inch == 72 pt) to decide
    # whether in-cell number annotations would still be readable.
    pt_per_cell_w = (width_in / max(n_cols, 1)) * 72
    pt_per_cell_h = (height_in / max(n_rows, 1)) * 72
    annotations_legible = (
        n_rows * n_cols <= _MAX_CELLS_FOR_ANNOTATIONS
        and min(pt_per_cell_w, pt_per_cell_h)
        >= _MIN_CELL_SIZE_FOR_ANNOTATIONS
    )

    layout = {
        "fig_width": width_in,
        "fig_height": height_in,
        "tick_label_size": float(
            np.clip(15.0 - 0.18 * largest_side, 9.0, 14.0)
        ),
        "annotation_size": float(
            np.clip(13.0 - 0.22 * largest_side, 5.0, 12.0)
        ),
        "title_size": float(np.clip(16.0 - 0.15 * largest_side, 9.0, 15.0)),
        # Rotate lead-time labels once they get dense enough to collide.
        "x_tick_rotation": 45.0 if n_cols > 12 else 0.0,
        "show_annotations": annotations_legible,
    }
    return layout


def _get_heatmap_var_labels(datastore: BaseRegularGridDatastore) -> list[str]:
    """Assemble y-axis labels ("name (unit)") for the state variables."""
    labels = []
    for name, unit in zip(
        datastore.get_vars_names(category="state"),
        datastore.get_vars_units(category="state"),
    ):
        # Omit the parenthesized unit entirely when none is provided.
        label = f"{name} ({unit})" if unit else name
        labels.append(_tex_safe(label))
    return labels


def _to_heatmap_matrix(values) -> np.ndarray:
"""Convert `(pred_steps, d_f)` values to a `(d_f, pred_steps)` matrix."""
if hasattr(values, "detach"):
values = values.detach().cpu().numpy()
return np.asarray(values, dtype=float).T


def _get_feature_scale(
    ds_stats: xr.Dataset, var_name: str, n_vars: int
) -> np.ndarray | None:
    """Return a 1D per-feature scale of length `n_vars`, or None.

    Any dimensions other than ``state_feature`` are averaged away so the
    result is strictly per-feature. Returns None when the variable is
    missing, lacks the feature dimension, or covers too few features.
    """
    feature_dim = "state_feature"
    if var_name not in ds_stats:
        return None

    da_scale = ds_stats[var_name]
    if feature_dim not in da_scale.dims:
        return None

    extra_dims = [dim for dim in da_scale.dims if dim != feature_dim]
    if extra_dims:
        da_scale = da_scale.mean(dim=extra_dims)

    flat_scale = np.asarray(da_scale.values, dtype=float).reshape(-1)
    # Fewer entries than plotted variables means the stats are unusable.
    if flat_scale.size < n_vars:
        return None
    return flat_scale[:n_vars]


def _get_heatmap_color_values(
    errors_np: np.ndarray, datastore: BaseRegularGridDatastore
) -> tuple[np.ndarray, str]:
    """Map raw errors to color values on a shared cross-variable scale.

    Falls back to the unscaled errors (with a warning) whenever the
    datastore cannot provide per-feature standardization statistics.
    Returns the color-value matrix and the matching colorbar label.
    """
    absolute_fallback = (errors_np, "Absolute scale")

    try:
        ds_state_stats = datastore.get_standardization_dataarray(
            category="state"
        )
    except (AttributeError, KeyError, TypeError, ValueError) as exc:
        warnings.warn(
            f"Could not load standardization stats ({exc}); "
            "falling back to absolute scale.",
            UserWarning,
            stacklevel=3,
        )
        return absolute_fallback

    n_vars = errors_np.shape[0]
    state_std = _get_feature_scale(ds_state_stats, "state_std", n_vars)
    if state_std is None:
        warnings.warn(
            "State standardization stats are unavailable; "
            "falling back to absolute scale.",
            UserWarning,
            stacklevel=3,
        )
        return absolute_fallback

    # When the std of the standardized one-step change is also available,
    # fold it into the scale; otherwise scale by the plain state std.
    diff_std = _get_feature_scale(
        ds_state_stats, "state_diff_std_standardized", n_vars
    )
    if diff_std is None:
        scale = state_std
        colorbar_label = "Relative scale (state stds)"
    else:
        scale = state_std * diff_std
        colorbar_label = "Error / Std(1-step change)"

    # Guard against zero or non-finite scale entries, which would turn
    # the division below into inf/NaN colors.
    safe_scale = np.where(
        np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps),
        scale,
        1.0,
    )
    return errors_np / safe_scale[:, None], colorbar_label


def _get_annotation_text_color(
    value: float, image: matplotlib.image.AxesImage
) -> str:
    """Pick black or white text so annotations stay readable on their cell."""
    # Non-finite values have no meaningful background color; default black.
    if not np.isfinite(value):
        return "black"

    red, green, blue, _alpha = image.cmap(image.norm(value))
    # Rec. 709 relative luminance of the cell's rendered background.
    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue
    return "black" if luminance >= 0.5 else "white"


def plot_on_axis(
ax,
da,
Expand All @@ -40,7 +178,7 @@ def plot_on_axis(
boundary_alpha=None,
crop_to_interior=False,
):
"""Plot weather state on given axis using datastore metadata.
"""Plot weather state on a projection-aware axis using datastore metadata.

Parameters
----------
Expand All @@ -67,9 +205,7 @@ def plot_on_axis(
-------
matplotlib.collections.QuadMesh
The mesh object created by pcolormesh.

"""

ax.coastlines(resolution="50m")
ax.add_feature(cfeature.BORDERS, linestyle="-", alpha=0.5)

Expand Down Expand Up @@ -154,63 +290,120 @@ def plot_on_axis(
return mesh


# rc_context applies NeurIPS font/text settings; figure size is overridden
# by the explicit figsize= below but font family and usetex stay in effect.
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
def plot_error_heatmap(
Comment on lines 351 to +352
Copy link
Copy Markdown
Collaborator

@sadamov sadamov Mar 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The figure size set here is immediately overridden by figsize= below, so this looks like dead code. The context applies NeurIPS font/text settings - add a one-liner to make that clear:

Suggested change
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
def plot_error_heatmap(
# rc_context applies NeurIPS font/text settings; figure size is overridden
# by the explicit figsize= below but font family and usetex stay in effect.
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_heatmap(

errors,
datastore: BaseRegularGridDatastore,
title=None,
):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
errors: (pred_steps, d_f)
Plot a heatmap of errors for state variables across forecast lead times.

Parameters
----------
errors : torch.Tensor
Error values with shape `(pred_steps, d_f)`.
datastore : BaseRegularGridDatastore
Datastore providing step length and variable metadata.
title : str, optional
Optional title for the figure.
"""
errors_np = errors.T.cpu().numpy() # (d_f, pred_steps)
errors_np = _to_heatmap_matrix(errors)
d_f, pred_steps = errors_np.shape
step_length = datastore.step_length

# Normalize all errors to [0,1] for color map
max_errors = errors_np.max(axis=1) # d_f
errors_norm = errors_np / np.expand_dims(max_errors, axis=1)

time_step_int, time_step_unit = utils.get_integer_time(step_length)
layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps)
color_values_np, colorbar_label = _get_heatmap_color_values(
errors_np, datastore
)

fig, ax = plt.subplots(figsize=(15, 10))
finite_color_values = color_values_np[np.isfinite(color_values_np)]
if finite_color_values.size == 0:
vmin, vmax = 0.0, 1.0
else:
vmin = float(finite_color_values.min())
vmax = float(finite_color_values.max())
if vmin >= 0.0:
vmin = 0.0
if np.isclose(vmin, vmax):
vmax = vmin + 1.0

ax.imshow(
errors_norm,
cmap="OrRd",
vmin=0,
vmax=1.0,
fig, ax = plt.subplots(
figsize=(layout["fig_width"], layout["fig_height"]),
constrained_layout=True,
)

im = ax.imshow(
color_values_np,
cmap=_HEATMAP_CMAP,
vmin=vmin,
vmax=vmax,
interpolation="none",
aspect="auto",
alpha=0.8,
)
cbar = fig.colorbar(im, ax=ax, pad=0.02)
cbar.set_label(_tex_safe(colorbar_label), size=layout["tick_label_size"])
cbar.ax.tick_params(labelsize=layout["tick_label_size"])
cbar.ax.yaxis.get_offset_text().set_fontsize(layout["tick_label_size"])

if layout["show_annotations"]:
for (j, i), error in np.ndenumerate(errors_np):
if np.isfinite(error):
formatted_error = (
f"{error:.3g}" if abs(error) < 1.0e4 else f"{error:.2E}"
)
else:
formatted_error = str(error)
text_color = _get_annotation_text_color(color_values_np[j, i], im)
ax.text(
i,
j,
formatted_error,
ha="center",
va="center",
usetex=False,
fontsize=layout["annotation_size"],
color=text_color,
)

# ax and labels
for (j, i), error in np.ndenumerate(errors_np):
# Numbers > 9999 will be too large to fit
formatted_error = f"{error:.3f}" if error < 9999 else f"{error:.2E}"
ax.text(i, j, formatted_error, ha="center", va="center", usetex=False)

# Ticks and labels
ax.set_xticks(np.arange(pred_steps))
pred_hor_i = np.arange(pred_steps) + 1
pred_hor_h = time_step_int * pred_hor_i
ax.set_xticklabels(pred_hor_h, size=_TICK_SIZE)
ax.set_xlabel(f"Lead time ({time_step_unit[0]})", size=_LABEL_SIZE)
ax.set_xticklabels(
pred_hor_h,
size=layout["tick_label_size"],
rotation=layout["x_tick_rotation"],
ha="right" if layout["x_tick_rotation"] > 0 else "center",
)
ax.set_xlabel(
f"Lead time ({time_step_unit[0]})", size=layout["tick_label_size"]
)

ax.set_yticks(np.arange(d_f))
var_names = datastore.get_vars_names(category="state")
var_units = datastore.get_vars_units(category="state")
y_ticklabels = [
_tex_safe(f"{name} ({unit})")
for name, unit in zip(var_names, var_units)
]
ax.set_yticklabels(y_ticklabels, rotation=30, size=_TICK_SIZE)
ax.set_yticklabels(
_get_heatmap_var_labels(datastore=datastore),
size=layout["tick_label_size"],
)

if title:
ax.set_title(title, size=_TITLE_SIZE)
ax.set_title(title, size=layout["title_size"])

return fig


def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
    """Deprecated alias kept for backward compatibility.

    Use :func:`plot_error_heatmap` instead; this wrapper forwards all
    arguments unchanged after emitting a :class:`DeprecationWarning`.
    """
    deprecation_msg = (
        "plot_error_map is deprecated, use plot_error_heatmap instead"
    )
    warnings.warn(deprecation_msg, DeprecationWarning, stacklevel=2)
    return plot_error_heatmap(errors, datastore=datastore, title=title)


@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
datastore: BaseRegularGridDatastore,
Expand Down
Loading
Loading