Improve metric heatmap readability and scaling (#375) #376

Open
kshirajahere wants to merge 22 commits into mllam:main from
kshirajahere:fix/-#375-Improvements-to-plot_error_map-function

Conversation

@kshirajahere
Contributor

@kshirajahere kshirajahere commented Mar 10, 2026

Describe your changes

This PR improves the metric heatmap plotting used during evaluation.

Summary of changes

  • renamed plot_error_map to plot_error_heatmap
  • updated the ARModel call site and related wording to use “heatmap” consistently
  • replaced per-row normalization with a single shared color scale across all variables
  • switched to a perceptually uniform colormap
  • added a colorbar for absolute-value interpretation
  • made figure size adapt to the number of variables and lead times
  • made tick-label and in-cell annotation font sizes adapt to heatmap dimensions
  • rotate x tick labels automatically for dense lead-time grids
  • added focused plotting regression tests for:
    • shared global color scaling
    • adaptive figure/font sizing behavior
  • added a changelog entry under unreleased -> Fixed
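The adaptive sizing idea in the bullets above can be sketched roughly as follows. This is a hypothetical helper with illustrative constants, not the exact implementation in neural_lam/vis.py: figure size grows with the grid, font sizes shrink with the larger dimension, and both are clamped to readable ranges.

```python
import numpy as np

# Hypothetical sketch of adaptive heatmap layout (constants are examples,
# not the merged values): larger grids get bigger figures and smaller text.
def heatmap_layout(n_rows, n_cols):
    max_dim = max(n_rows, n_cols)
    return {
        "fig_width": float(np.clip(0.5 * n_cols + 4.0, 10.0, 30.0)),
        "fig_height": float(np.clip(0.4 * n_rows + 3.0, 6.0, 25.0)),
        "tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 9.0, 14.0)),
        # Rotate x tick labels once the lead-time axis becomes dense.
        "rotate_xticks": n_cols > 20,
    }
```

For a small 4×5 grid this keeps a compact figure and 14 pt ticks; for a 40×120 grid it maxes out the width, drops ticks to the 9 pt floor, and rotates the x labels.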

Motivation / context

Fixes the three issues raised in #375:

  • misleading function naming
  • heatmap colors not being comparable across variables
  • non-adaptive text sizing/layout for dense heatmaps

Dependencies

No new external dependencies.

Issue Link

Closes #375

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)
  • Once the PR is ready to be merged, squash commits and merge the PR.

Tests

Ran locally:

  • python -m pytest -q tests/test_plotting.py -k "heatmap" -> passed
  • python -m py_compile neural_lam/vis.py neural_lam/models/ar_model.py tests/test_plotting.py -> passed

Notes:

  • python -m pytest -q tests/test_plotting.py hits an existing cartopy/Natural Earth data path permission issue on this machine when the older integration plotting tests try to access coastline data. The new focused heatmap tests pass.

Refactor error map plotting function to heatmap and improve layout computation.
Added a HeatmapDatastore class for testing heatmap plots and updated tests to verify heatmap behavior.
Scale metric heatmap figure size and improve readability for larger outputs.
@kshirajahere kshirajahere force-pushed the fix/-#375-Improvements-to-plot_error_map-function branch from 7438177 to 4069d62 on March 10, 2026 21:11
Sir-Sloth-The-Lazy added a commit to Sir-Sloth-The-Lazy/neural-lam that referenced this pull request Mar 11, 2026
…cale control

- expose optional vmin/vmax args so callers can fix a shared colour
  scale across multiple heatmaps (e.g. val vs test comparison)
- propagate params through the plot_error_map deprecation wrapper
- add tests: explicit vmin/vmax respected, cross-run scale consistency
- update docstring to document new parameters

Follow-up to mllam#376
@sahilkr31

Hi @sadamov, and thanks @kshirajahere for the PR.

I went through the changes in this PR and the switch to a shared global color scale makes sense, since it allows the heatmap colours to reflect absolute error magnitudes and makes comparisons across variables easier.

One thing I was thinking about regarding point 2 is whether global scaling might sometimes reduce the contrast for variables whose error range is much smaller than others. In those cases a row could end up looking almost uniform even if there are differences across lead times.

Would it make sense to support both behaviours (for example using global scaling by default but optionally allowing per-variable normalization when someone wants to highlight within-variable patterns), or do you think keeping a single global scale is preferable for the evaluation plots here?

Curious to hear your thoughts also.

@Joltsy10

The contrast problem @sahilkr31 points out exists because raw absolute errors across variables with different physical units aren't comparable to begin with. 500 Pa of surface pressure error and 1 K of temperature error sit on completely different scales, so a shared raw colormap will always be dominated by high-magnitude variables regardless of how you normalize it.

Normalizing each variable by its climatological standard deviation before applying the global scale fixes this properly. The colormap then encodes normalized error: how large the error is relative to natural variability, which is actually consistent across variables. Same color, same forecast skill, regardless of units.

The per-variable toggle is just letting the user pick which misleading view they want rather than solving the underlying issue.😄
Happy to be corrected if raw absolute errors are preferred here for a specific reason.
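The normalization described above can be sketched with a toy example (the numbers, variable set, and units here are hypothetical, chosen only to illustrate the idea):

```python
import numpy as np

# Toy example: errors for 3 variables (rows) over 4 lead times (cols), in
# physical units. Raw magnitudes are not comparable across rows (Pa vs K vs m/s).
errors = np.array([
    [300.0, 400.0, 450.0, 500.0],  # surface pressure error [Pa]
    [0.5,   0.8,   1.0,   1.2],    # temperature error [K]
    [1.0,   1.5,   2.0,   2.5],    # wind speed error [m/s]
])

# Assumed climatological standard deviation per variable (same units per row).
clim_std = np.array([800.0, 5.0, 4.0])

# Dividing each row by its climatological std yields unitless error relative
# to natural variability, so a single shared color scale across all rows
# (e.g. vmin=0, vmax=errors_norm.max()) stays meaningful.
errors_norm = errors / clim_std[:, None]
```

After this rescaling, equal colors really do mean equal skill: a cell of 0.375 for pressure and 0.375 for wind would render identically.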

@sadamov
Collaborator

sadamov commented Mar 12, 2026

500 Pa of surface pressure error and 1 K of temperature error sit on completely different scales, so a shared raw colormap will always be dominated by high-magnitude variables regardless of how you normalize it.
Normalizing each variable by its climatological standard deviation before applying the global scale fixes this properly. The colormap then encodes normalized error: how large the error is relative to natural variability, which is actually consistent across variables. Same color, same forecast skill, regardless of units.

Agreed, this sounds like a good way to do it. We can implement this cleanly by expanding plot_error_map to accept both physical and normalized errors separately. Physical values are always shown as text annotations on the heatmap (unchanged). When errors_norm is also provided, it drives the colormap instead of the default per-variable-max normalization:

def plot_error_map(
    errors,
    datastore: BaseRegularGridDatastore,
    title=None,
    errors_norm=None,
):
    """
    errors: (pred_steps, d_f) — physical-unit values, shown as text annotations
    errors_norm: (pred_steps, d_f) — optional values used for colormap scaling.
        When provided (e.g. errors divided by climatological std), the colormap
        encodes skill relative to natural variability, making colors comparable
        across variables with different units. When None, each variable is
        normalized by its own maximum error.
    """

The caller in ar_model passes metric_tensor_averaged (already in normalized units) as errors_norm and metric_tensor_averaged * state_std as errors, so no redundant rescaling occurs.
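The selection rule implied by that API can be sketched as a small standalone helper (hypothetical, not the actual plot_error_map body): when errors_norm is provided it drives the colormap, otherwise the function falls back to the per-variable-max normalization.

```python
import numpy as np

# Hypothetical sketch of the color-selection rule proposed above: errors_norm
# drives the colormap when given, else fall back to per-variable max.
def colormap_values(errors, errors_norm=None):
    """Inputs have shape (pred_steps, d_f); returns (d_f, pred_steps)."""
    mat = (errors_norm if errors_norm is not None else errors).T
    if errors_norm is not None:
        return mat  # already comparable across variables, use as-is
    # Default: normalize each variable (row) by its own maximum error.
    return mat / mat.max(axis=1, keepdims=True)
```

With this split, the physical-unit errors remain available for the text annotations while the normalized matrix alone decides the background colors.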

@sadamov sadamov self-requested a review March 12, 2026 20:53
@sadamov sadamov self-assigned this Mar 12, 2026
@sadamov sadamov added the enhancement New feature or request label Mar 12, 2026
@sadamov
Collaborator

sadamov commented Mar 12, 2026

@kshirajahere could you share some pictures of the new heatmap for low/high number of vars and lead_times?

@kshirajahere
Contributor Author

kshirajahere commented Mar 13, 2026

@sadamov Sure
massive_grid_120leads_40vars
full_real_subset_8vars_65leads

You will need to open the images in a new tab to see how they render now, without overlapping; I tried to make this test as rigorous as possible.
Some simpler heatmaps:
small_real_subset_5vars_6leads
medium_real_subset_8vars_20leads

I've generated some sample heatmaps using the current code from the PR. I used the MEPS sample data to showcase real-world variable structures along with a simulated "massive grid" scenario like the one mentioned in the issue.

@kshirajahere
Contributor Author

kshirajahere commented Mar 13, 2026

@sahilkr31 @Joltsy10 @sadamov Thanks for the great discussion here. I agree that comparing raw values with varying physical units like Pressure vs Temp can misrepresent skill when they share a single color scale.

Based on this, I've ensured the heatmap code correctly handles whatever normalization metric is passed to it by the ARModel. The plotting function now uses a global color scale for the matrix provided, meaning that if it receives climatology-normalized errors from the training/evaluation loop (where natural variability is standardized), the heatmap's color mapping will automatically reflect unit-independent forecast skill! And, of course, the annotations continue to show the absolute numeric values properly, avoiding overlapping and adjusting depending on cell size.

@joeloskarsson joeloskarsson self-requested a review March 13, 2026 21:42
@joeloskarsson
Collaborator

Hi, thanks for working on this, and great that you raised the issue @sadamov. These plots have often looked quite bad.

I was a bit worried by the discussion of using a joint colormap, since (as was later pointed out) this would not look reasonable at all due to the magnitudes of the variables being so different. I like the current idea though, to have the colormap based on error in standardized units and then always the original numbers in text. From the examples it looks like some variables still dominate though. Might need to mix in the variance of the single-step delta in the standardization to get something reasonable.

Personally I am not much of a fan of the switch to a perceptually uniform colormap for this. The change from white to red clearly shows that things got worse, whereas with viridis one has to consult the color bar to recall e.g. that lighter is worse and that green is not good. But I understand this is a bit of a personal preference.

@kshirajahere
Contributor Author

Thanks Joel @joeloskarsson
The latest update keeps the numeric annotations in the original units, but changes the background colors to use a relative cross-variable scale instead of raw physical magnitudes. Please review @sadamov @joeloskarsson

Concretely:

  • the heatmap now normalizes the background colors using datastore standardization stats,
  • and when state_diff_std_standardized is available it uses that as well, so the colors reflect magnitude relative to typical 1-step variability for each variable,
  • while the cell annotations still show the original rescaled values.

I also switched the colormap away from viridis to a white-to-red sequential scale, since that seemed more in line with your point that “worse” should read visually as red without needing to decode the colorbar.

I added tests for:

  • relative cross-variable color scaling,
  • the white-to-red colormap behavior,
  • the adaptive layout/font sizing,
  • dense-grid annotation suppression,
  • and the deprecated wrapper.

I’ve attached:

  1. a raw-global-scale MEPS-style baseline (image),
  2. the updated relative-scale version on the same matrix (image),
  3. and a dense-grid readability example (image).

Sir-Sloth-The-Lazy added a commit to Sir-Sloth-The-Lazy/neural-lam that referenced this pull request Mar 18, 2026
…cale control

- expose optional vmin/vmax args so callers can fix a shared colour
  scale across multiple heatmaps (e.g. val vs test comparison)
- propagate params through the plot_error_map deprecation wrapper
- add tests: explicit vmin/vmax respected, cross-run scale consistency
- update docstring to document new parameters

Follow-up to mllam#376
@sadamov sadamov added this to the v0.7.0 (proposed) milestone Mar 21, 2026
Collaborator

@sadamov sadamov left a comment


Almost good to go — a few small fixes needed before merge.

Comment thread neural_lam/vis.py Outdated
return {
"fig_width": fig_width,
"fig_height": fig_height,
"tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 7.0, 14.0)),
Collaborator


7 pt is at the edge of readability for large heatmaps (e.g. 30 vars × 40 steps). Suggest raising the floor to 9 pt.

Suggested change
"tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 7.0, 14.0)),
"tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 9.0, 14.0)),

Comment thread neural_lam/vis.py Outdated
category="state"
)
except (AttributeError, KeyError, TypeError, ValueError):
return errors_np, "Relative scale"
Collaborator


When standardization stats are unavailable the fallback returns raw error values, which are on an absolute scale — not a relative one. The label is misleading.

Suggested change
return errors_np, "Relative scale"
return errors_np, "Absolute scale"

Comment thread neural_lam/vis.py Outdated
n_vars = errors_np.shape[0]
state_std = _get_feature_scale(ds_state_stats, "state_std", n_vars)
if state_std is None:
return errors_np, "Relative scale"
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue as above.

Suggested change
return errors_np, "Relative scale"
return errors_np, "Absolute scale"

Comment thread neural_lam/vis.py Outdated
)
if state_diff_std_standardized is not None:
scale = scale * state_diff_std_standardized
colorbar_label = "Relative scale (1-step diff stds)"
Collaborator

@sadamov sadamov Mar 25, 2026


The label is correct but I would prefer error / σ(1-step change) where σ(1-step change) is state_std × state_diff_std_standardized — the std of 1-step differences in physical units, i.e. roughly the error a persistence forecast makes after one step. A value of 1.0 means the model error equals a typical hour-to-hour change; 0.5 means twice as accurate as persistence.

Suggested change
colorbar_label = "Relative scale (1-step diff stds)"
colorbar_label = "Error / σ(1-step change)"
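As a rough numeric illustration of that label (all values here are hypothetical; the variable names follow the discussion above, not the exact code):

```python
import numpy as np

# sigma(1-step change) in physical units: per-variable state std times the
# std of standardized 1-step differences. Values below are made up.
state_std = np.array([800.0, 5.0])                   # physical units
state_diff_std_standardized = np.array([0.02, 0.3])  # unitless
sigma_one_step = state_std * state_diff_std_standardized  # -> [16.0, 1.5]

errors = np.array([[8.0, 1.5]])  # (pred_steps, d_f), physical units
relative = errors / sigma_one_step
# relative == 1.0 -> error equals a typical 1-step change;
# relative == 0.5 -> roughly twice as accurate as persistence after one step.
```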

Collaborator


If you prefer Std instead of σ that's also great :)

Collaborator


Or (1-step difference) works, too. Up to you!

Comment thread neural_lam/vis.py Outdated

"""

"""Plot weather state on a projection-aware axis using datastore metadata."""
Collaborator


The original plot_on_axis docstring had a full NumPy-style Parameters/Returns section that is used by IDEs and documentation generators. The PR replaced it with a single-line summary unrelated to the PR goals — please restore it.

Suggested change
"""Plot weather state on a projection-aware axis using datastore metadata."""
"""Plot weather state on a projection-aware axis using datastore metadata.
Parameters
----------
ax : matplotlib.axes.Axes
The axis to plot on. Should have a cartopy projection.
da : xarray.DataArray
The data to plot. Should have shape (N_grid,).
datastore : BaseRegularGridDatastore
The datastore containing metadata about the grid.
vmin : float, optional
Minimum value for color scale.
vmax : float, optional
Maximum value for color scale.
ax_title : str, optional
Title for the axis.
cmap : str or matplotlib.colors.Colormap, optional
Colormap to use for plotting.
boundary_alpha : float, optional
If provided, overlay boundary mask with given alpha transparency.
crop_to_interior : bool, optional
If True, crop the plot to the interior region.
Returns
-------
matplotlib.collections.QuadMesh
The mesh object created by pcolormesh.
"""

Comment thread neural_lam/vis.py Outdated
Comment on lines +264 to +307
color_values=None,
colorbar_label=None,
):
"""
Plot a heatmap of errors of different variables at different
predictions horizons
errors: (pred_steps, d_f)
Plot a heatmap of errors for state variables across forecast lead times.

Parameters
----------
errors : torch.Tensor
Error values with shape `(pred_steps, d_f)`.
datastore : BaseRegularGridDatastore
Datastore providing step length and variable metadata.
title : str, optional
Optional title for the figure.
color_values : torch.Tensor, optional
Optional values used only for the background colors. If omitted,
colors are normalized from ``errors`` using datastore state-variable
standardization statistics.
colorbar_label : str, optional
Optional label for the colorbar. If omitted, an automatic label is
chosen based on the color normalization used.
"""
errors_np = errors.T.cpu().numpy() # (d_f, pred_steps)
errors_np = _to_heatmap_matrix(errors)
d_f, pred_steps = errors_np.shape
step_length = datastore.step_length

# Normalize all errors to [0,1] for color map
max_errors = errors_np.max(axis=1) # d_f
errors_norm = errors_np / np.expand_dims(max_errors, axis=1)

time_step_int, time_step_unit = utils.get_integer_time(step_length)
layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps)

fig, ax = plt.subplots(figsize=(15, 10))
if color_values is None:
color_values_np, default_colorbar_label = _get_heatmap_color_values(
errors_np, datastore
)
else:
color_values_np = _to_heatmap_matrix(color_values)
default_colorbar_label = "Relative scale"
if color_values_np.shape != errors_np.shape:
raise ValueError(
"color_values must have the same shape as errors: "
f"got {color_values_np.T.shape} and {errors_np.T.shape}"
)

ax.imshow(
errors_norm,
cmap="OrRd",
vmin=0,
vmax=1.0,
if colorbar_label is None:
colorbar_label = default_colorbar_label
Collaborator

@sadamov sadamov Mar 25, 2026


color_values and colorbar_label are never used at the only call site (ar_model.py) and add ~20 lines for zero current benefit. Three small removals needed:

1. Remove the two params from the signature (lines 264–265):

# delete these two lines
    color_values=None,
    colorbar_label=None,

2. Remove their docstring entries (lines 278–284):

# delete these seven lines
    color_values : torch.Tensor, optional
        Optional values used only for the background colors. If omitted,
        colors are normalized from ``errors`` using datastore state-variable
        standardization statistics.
    colorbar_label : str, optional
        Optional label for the colorbar. If omitted, an automatic label is
        chosen based on the color normalization used.

3. Replace the if/else block (lines 293–307) with the direct call:

    color_values_np, colorbar_label = _get_heatmap_color_values(
        errors_np, datastore
    )

@kshirajahere
Contributor Author

kshirajahere commented Mar 25, 2026

@sadamov Thanks for the review, I will fix these and get back to you ASAP :)

@sadamov
Collaborator

sadamov commented Mar 25, 2026

Make sure to also run pre-commits (and pytest) locally before the next commit 🙏

@kshirajahere
Contributor Author

I reran the plotting tests locally on the latest branch state before pushing, and tests/test_plotting.py -q passes on my side. The current push also has all pre-commit jobs green.
@sadamov Ready for a re-review :)

@kshirajahere kshirajahere requested a review from sadamov March 25, 2026 14:42
Collaborator

@sadamov sadamov left a comment


Looking good, just a few minor things to fix before merge.

Comment thread neural_lam/vis.py Outdated
Comment thread neural_lam/vis.py Outdated
Comment on lines +75 to +80
if len(var_names) < n_vars:
var_names.extend(
[f"state_feature_{i}" for i in range(len(var_names), n_vars)]
)
if len(var_units) < n_vars:
var_units.extend([""] * (n_vars - len(var_units)))
Collaborator


n_vars always equals errors.shape[0] (from the same datastore that provides var names), so the padding block can never trigger. Simplify to:

Suggested change
if len(var_names) < n_vars:
var_names.extend(
[f"state_feature_{i}" for i in range(len(var_names), n_vars)]
)
if len(var_units) < n_vars:
var_units.extend([""] * (n_vars - len(var_units)))
var_names = datastore.get_vars_names(category="state")
var_units = datastore.get_vars_units(category="state")
return [
_tex_safe(f"{n} ({u})" if u else n)
for n, u in zip(var_names, var_units)
]

Comment thread neural_lam/vis.py Outdated
Comment thread neural_lam/vis.py Outdated
Comment thread neural_lam/vis.py
Comment on lines 286 to +287
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
def plot_error_heatmap(
Collaborator

@sadamov sadamov Mar 28, 2026


The figure size set here is immediately overridden by figsize= below, so this looks like dead code. The rc_context applies NeurIPS font/text settings; add a one-liner to make that clear:

Suggested change
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
def plot_error_heatmap(
# rc_context applies NeurIPS font/text settings; figure size is overridden
# by the explicit figsize= below but font family and usetex stay in effect.
@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_error_heatmap(

Comment thread tests/test_plotting.py Outdated
Comment on lines +243 to +271
def test_plot_error_heatmap_adapts_figure_and_font_sizes():
"""Dense heatmaps should get more space and smaller text."""
small_errors = torch.ones((4, 5))
large_errors = torch.ones((20, 30))

small_fig = vis.plot_error_heatmap(
small_errors, datastore=HeatmapDatastore(n_vars=small_errors.shape[1])
)
large_fig = vis.plot_error_heatmap(
large_errors, datastore=HeatmapDatastore(n_vars=large_errors.shape[1])
)

small_ax = small_fig.axes[0]
large_ax = large_fig.axes[0]

assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0]
assert large_fig.get_size_inches()[1] > small_fig.get_size_inches()[1]
assert (
large_ax.get_yticklabels()[0].get_fontsize()
< small_ax.get_yticklabels()[0].get_fontsize()
)
assert large_ax.texts[0].get_fontsize() < small_ax.texts[0].get_fontsize()
assert large_ax.get_xticklabels()[0].get_rotation() == 45.0

plt.close(small_fig)
plt.close(large_fig)


def test_plot_error_heatmap_skips_annotations_for_very_dense_grids():
Collaborator


test_plot_error_heatmap_adapts_figure_and_font_sizes and test_plot_error_heatmap_skips_annotations_for_very_dense_grids share the same setup and exercise the same _compute_heatmap_layout path. Merge into one:

Suggested change
def test_plot_error_heatmap_adapts_figure_and_font_sizes():
"""Dense heatmaps should get more space and smaller text."""
small_errors = torch.ones((4, 5))
large_errors = torch.ones((20, 30))
small_fig = vis.plot_error_heatmap(
small_errors, datastore=HeatmapDatastore(n_vars=small_errors.shape[1])
)
large_fig = vis.plot_error_heatmap(
large_errors, datastore=HeatmapDatastore(n_vars=large_errors.shape[1])
)
small_ax = small_fig.axes[0]
large_ax = large_fig.axes[0]
assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0]
assert large_fig.get_size_inches()[1] > small_fig.get_size_inches()[1]
assert (
large_ax.get_yticklabels()[0].get_fontsize()
< small_ax.get_yticklabels()[0].get_fontsize()
)
assert large_ax.texts[0].get_fontsize() < small_ax.texts[0].get_fontsize()
assert large_ax.get_xticklabels()[0].get_rotation() == 45.0
plt.close(small_fig)
plt.close(large_fig)
def test_plot_error_heatmap_skips_annotations_for_very_dense_grids():
def test_plot_error_heatmap_adapts_layout_for_grid_size():
"""Dense heatmaps get larger figures, smaller fonts, and suppress annotations."""
small_fig = vis.plot_error_heatmap(torch.ones((4, 5)), datastore=HeatmapDatastore(n_vars=5))
large_fig = vis.plot_error_heatmap(torch.ones((20, 30)), datastore=HeatmapDatastore(n_vars=30))
dense_fig = vis.plot_error_heatmap(torch.ones((40, 50)), datastore=HeatmapDatastore(n_vars=50))
assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0]
assert large_fig.axes[0].get_yticklabels()[0].get_fontsize() < small_fig.axes[0].get_yticklabels()[0].get_fontsize()
assert large_fig.axes[0].get_xticklabels()[0].get_rotation() == 45.0
assert len(dense_fig.axes[0].texts) == 0
assert dense_fig.get_size_inches()[0] > 18.0
plt.close(small_fig); plt.close(large_fig); plt.close(dense_fig)

Comment thread tests/test_plotting.py
Comment thread tests/test_plotting.py
Comment thread tests/test_plotting.py
Comment thread CHANGELOG.md Outdated
@kshirajahere
Contributor Author

Hey @sadamov, addressed the latest review changes; ready for review.

@kshirajahere
Contributor Author

Also ran all tests locally before pushing.

@kshirajahere kshirajahere requested a review from sadamov March 28, 2026 13:24
Collaborator

@sadamov sadamov left a comment


I really like these new maps! They will prove very useful. This PR is proposed for v0.7.0, is marked as ready, and will be discussed at the next dev meeting. @kshirajahere thanks for working on this over the course of the past weeks!

@sadamov sadamov added the ready Review complete - proposed for milestone label Apr 9, 2026
@sadamov sadamov changed the title fix: improve metric heatmap readability and scaling (Closes #375) Improve metric heatmap readability and scaling (#375) Apr 13, 2026
@sadamov
Collaborator

sadamov commented Apr 15, 2026

@kshirajahere I saw you added the docstrings asked for during the dev meeting. Great! Now we wait for #208.

sadamov and others added 4 commits April 16, 2026 10:25
…parameter

Adds `normalization: str = "state_std"` to `plot_error_heatmap` and
rewrites `_get_heatmap_color_values` to support two explicit modes:
- `"state_std"`: RMSE / state_std, white-to-red Reds colormap
- `"one_step"`: RMSE / diff_std, white-to-red Reds colormap

Both modes fall back to per-variable max (also Reds, [0,1]) when their
required stat is missing. The two modes never silently upgrade to each
other. Removes the custom `_HEATMAP_CMAP` in favour of the built-in
`"Reds"` colormap string. Updates tests accordingly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
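The two-mode behavior the commit message describes could look roughly like this. A simplified sketch with illustrative names; the real code in neural_lam/vis.py differs in detail:

```python
import numpy as np

# Simplified sketch of the two explicit normalization modes. Each mode falls
# back to per-variable max when its required stat is missing, and the modes
# never silently switch to each other.
def heatmap_color_values(errors, normalization, state_std=None, diff_std=None):
    """errors: (d_f, pred_steps); returns (color_values, colorbar_label)."""
    if normalization == "state_std" and state_std is not None:
        return errors / state_std[:, None], "Error / state std"
    if normalization == "one_step" and diff_std is not None:
        return errors / diff_std[:, None], "Error / Std (1-step difference)"
    # Legacy fallback: per-variable max, values in [0, 1].
    return errors / errors.max(axis=1, keepdims=True), "Per-variable max"
```

Making the fallback label explicit addresses the earlier review concern that a silent fallback could mislead users into reading an absolute-scale heatmap as relative.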
@sadamov
Collaborator

sadamov commented Apr 17, 2026

Hi @kshirajahere,

After the dev meeting I took another look at some realistic heatmaps using the COSMO checkpoint from the boundary paper. As @joeloskarsson mentioned, the heatmaps are really dominated by a few variables. Also, I think the three-level silent fallback can be a bit dangerous, as the user might not get the heatmap they expected. Because we already completed the PR review I implemented the changes here for you, so as not to cause you more work. But I would appreciate your review, and if you agree, merge with #376. Here are four plots for four different colormaps based on different scalings/normalizations (the skill one is not part of the PR, because we said we want only red tones 🔴):

cosmo_skill_heatmap.pdf (omitted from PR)
cosmo_per_var_max_heatmap.pdf (legacy fallback)
cosmo_one_step_std_heatmap.pdf (optional via new argument)
cosmo_global_state_std_heatmap.pdf (default)

In this PR we had the one-step diff as the default. But as you can see, this is basically dominated by the pressure variables, which change only very slightly in one forecast step. Sea level pressure is actually a very important variable for assessing model skill, also in traditional NWP models: it contains information about the whole atmospheric column and therefore about the layering and horizontal gradients -> wind.
But I would argue that just seeing "yes, pressure is important" is not what these score cards are used for (I am exaggerating a bit). Rather, we would like to see a somewhat useful comparison between variables, and the global state std allows for that. We actually see the typical dynamic (i.e. fast changing) variables being the most red, like wind, vertical motion, and humidity.

If you agree with my logic here and with the PR into your branch, feel free to merge. Or let me know if you see some other reason why my default is bad: kshirajahere#1

@kshirajahere
Contributor Author

Hey @sadamov, went through the stacked PR. It looks good from what I have reviewed. (Just a few questions before merging; I have raised them on kshirajahere#1.)

This feels like a cleaner match to the dev-meeting discussion and should make the heatmap behavior easier to reason about from the API and the colorbar label. Thanks a lot ❤️

@kshirajahere
Contributor Author

kshirajahere commented Apr 17, 2026

Merged kshirajahere#1 into this PR, Thanks @sadamov :D


Labels

enhancement New feature or request ready Review complete - proposed for milestone

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Improve metric heatmap readability and scaling

5 participants