Improve metric heatmap readability and scaling (#375) #376
kshirajahere wants to merge 22 commits into mllam:main from
Conversation
Refactor error map plotting function to heatmap and improve layout computation.
Added a HeatmapDatastore class for testing heatmap plots and updated tests to verify heatmap behavior.
Scale metric heatmap figure size and improve readability for larger outputs.
…ense grid annotation suppression
7438177 to 4069d62
…cale control

- expose optional vmin/vmax args so callers can fix a shared colour scale across multiple heatmaps (e.g. val vs test comparison)
- propagate params through the plot_error_map deprecation wrapper
- add tests: explicit vmin/vmax respected, cross-run scale consistency
- update docstring to document new parameters

Follow-up to mllam#376
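A rough sketch of what the deprecation wrapper mentioned in that commit could forward (assumptions: the exact signatures in the PR may differ, and plot_error_heatmap here stands for the renamed function, assumed to accept vmin/vmax):

import warnings

def plot_error_map(errors, datastore, title=None, vmin=None, vmax=None):
    """Deprecated wrapper kept for backwards compatibility (sketch only)."""
    warnings.warn(
        "plot_error_map is deprecated; use plot_error_heatmap instead",
        DeprecationWarning,
        stacklevel=2,
    )
    # Forward the new colour-scale arguments unchanged so callers can pin
    # a shared scale across several heatmaps (e.g. validation vs test).
    return plot_error_heatmap(errors, datastore, title=title, vmin=vmin, vmax=vmax)

Callers comparing, say, validation and test runs would then pass the same vmin/vmax to both calls to keep the colour scales aligned.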
|
Hi @sadamov, and thanks @kshirajahere for the PR. I went through the changes in this PR and the switch to a shared global color scale makes sense, since it allows the heatmap colours to reflect absolute error magnitudes and makes comparisons across variables easier. One thing I was thinking about regarding point 2 is whether global scaling might sometimes reduce the contrast for variables whose error range is much smaller than others. In those cases a row could end up looking almost uniform even if there are differences across lead times. Would it make sense to support both behaviours (for example using global scaling by default but optionally allowing per-variable normalization when someone wants to highlight within-variable patterns), or do you think keeping a single global scale is preferable for the evaluation plots here? Curious to hear your thoughts also. |
|
The contrast problem @sahilkr31 points out exists because raw absolute errors across variables with different physical units aren't comparable to begin with. 500 Pa of surface pressure error and 1 K of temperature error sit on completely different scales, so a shared raw colormap will always be dominated by high-magnitude variables regardless of how you normalize it. Normalizing each variable by its climatological standard deviation before applying the global scale fixes this properly. The colormap then encodes normalized error (how large the error is relative to natural variability), which is consistent across variables. Same color, same forecast skill, regardless of units. The per-variable toggle just lets the user pick which misleading view they want rather than solving the underlying issue. 😄 |
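A minimal sketch of that normalization, with invented numbers and a hypothetical clim_std array standing in for the datastore's standardization statistics:

import numpy as np

# Errors in physical units, shape (n_vars, pred_steps); numbers invented.
errors_phys = np.array(
    [[400.0, 500.0, 600.0],  # surface pressure error [Pa]
     [0.8, 1.0, 1.2]]        # 2m temperature error [K]
)

# Climatological standard deviation per variable (assumed values).
clim_std = np.array([700.0, 5.0])

# Normalized error: error magnitude relative to natural variability.
errors_norm = errors_phys / clim_std[:, None]

# A single shared colour scale is now meaningful across both variables.
shared_vmax = errors_norm.max()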
Agreed, this sounds like a good way to do it. We can implement this cleanly by expanding plot_error_map to accept both physical and normalized errors separately. Physical values are always shown as text annotations on the heatmap (unchanged). When errors_norm is also provided, it drives the colormap instead of the default per-variable-max normalization:

def plot_error_map(
    errors,
    datastore: BaseRegularGridDatastore,
    title=None,
    errors_norm=None,
):
    """
    errors: (pred_steps, d_f) — physical-unit values, shown as text annotations
    errors_norm: (pred_steps, d_f) — optional values used for colormap scaling.
        When provided (e.g. errors divided by climatological std), the colormap
        encodes skill relative to natural variability, making colors comparable
        across variables with different units. When None, each variable is
        normalized by its own maximum error.
    """

The caller in ar_model passes metric_tensor_averaged (already in normalized units) as errors_norm and metric_tensor_averaged * state_std as errors, so no redundant rescaling occurs. |
|
@kshirajahere could you share some pictures of the new heatmap for low/high number of vars and lead_times? |
|
@sadamov Sure, you will need to open the images in a new tab to see how they render now, without overlapping. I tried to make this test as rigorous as possible. I've generated some sample heatmaps using the current code from the PR. I used the MEPS sample data to showcase real-world variable structures, along with a simulated "massive grid" scenario like the one mentioned in the issue. |
|
@sahilkr31 @Joltsy10 @sadamov Thanks for the great discussion here. I agree that comparing raw values with varying physical units like Pressure vs Temp can misrepresent skill when they share a single color scale. Based on this, I've ensured the heatmap code correctly handles whatever normalization metric is passed to it by the |
|
Hi, thanks for working on this, and great that you raised the issue @sadamov. These plots have often looked quite bad. I was a bit worried by the discussion of using a joint colormap, since (as later pointed out) this would not look reasonable at all due to the magnitudes of the variables being so different. I like the current idea though, to have the colormap based on error in standardized units and then always the original numbers in text. From the examples it looks like some variables still dominate though. Might need to mix in the variance of the single-step delta in the standardization to get something reasonable. Personally I am not much of a fan of the switch to a perceptually uniform colormap for this. The change from white to red clearly shows that things got worse, and with |
|
Thanks Joel @joeloskarsson. Concretely:

I also switched the colormap away from …

I added tests for: …

I’ve attached: …
|
…cale control

- expose optional vmin/vmax args so callers can fix a shared colour scale across multiple heatmaps (e.g. val vs test comparison)
- propagate params through the plot_error_map deprecation wrapper
- add tests: explicit vmin/vmax respected, cross-run scale consistency
- update docstring to document new parameters

Follow-up to mllam#376
sadamov
left a comment
Almost good to go — a few small fixes needed before merge.
| return { | ||
| "fig_width": fig_width, | ||
| "fig_height": fig_height, | ||
| "tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 7.0, 14.0)), |
7 pt is at the edge of readability for large heatmaps (e.g. 30 vars × 40 steps). Suggest raising the floor to 9 pt.
| "tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 7.0, 14.0)), | |
| "tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 9.0, 14.0)), |
| category="state" | ||
| ) | ||
| except (AttributeError, KeyError, TypeError, ValueError): | ||
| return errors_np, "Relative scale" |
When standardization stats are unavailable the fallback returns raw error values, which are on an absolute scale — not a relative one. The label is misleading.
| return errors_np, "Relative scale" | |
| return errors_np, "Absolute scale" |
| n_vars = errors_np.shape[0] | ||
| state_std = _get_feature_scale(ds_state_stats, "state_std", n_vars) | ||
| if state_std is None: | ||
| return errors_np, "Relative scale" |
Same issue as above.
| return errors_np, "Relative scale" | |
| return errors_np, "Absolute scale" |
| ) | ||
| if state_diff_std_standardized is not None: | ||
| scale = scale * state_diff_std_standardized | ||
| colorbar_label = "Relative scale (1-step diff stds)" |
The label is correct but I would prefer error / σ(1-step change) where σ(1-step change) is state_std × state_diff_std_standardized — the std of 1-step differences in physical units, i.e. roughly the error a persistence forecast makes after one step. A value of 1.0 means the model error equals a typical hour-to-hour change; 0.5 means twice as accurate as persistence.
| colorbar_label = "Relative scale (1-step diff stds)" | |
| colorbar_label = "Error / σ(1-step change)" |
If you prefer Std instead of σ that's also great :)
Or (1-step difference) works, too. Up to you!
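A tiny worked example of that interpretation (all numbers invented for illustration):

# Suppose a variable has state_std = 5.0 K and a standardized 1-step diff
# std of 0.2, so a typical one-step change is 5.0 * 0.2 = 1.0 K.
state_std = 5.0
state_diff_std_standardized = 0.2
sigma_one_step = state_std * state_diff_std_standardized  # 1.0 K

rmse_physical = 0.5  # model error at some lead time, in K
normalized = rmse_physical / sigma_one_step
# normalized == 0.5: the error is half a typical step-to-step change,
# i.e. roughly twice as accurate as a one-step persistence forecast.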
|
|
||
| """ | ||
|
|
||
| """Plot weather state on a projection-aware axis using datastore metadata.""" |
The original plot_on_axis docstring had a full NumPy-style Parameters/Returns section that is used by IDEs and documentation generators. The PR replaced it with a single-line summary unrelated to the PR goals — please restore it.
| """Plot weather state on a projection-aware axis using datastore metadata.""" | |
| """Plot weather state on a projection-aware axis using datastore metadata. | |
| Parameters | |
| ---------- | |
| ax : matplotlib.axes.Axes | |
| The axis to plot on. Should have a cartopy projection. | |
| da : xarray.DataArray | |
| The data to plot. Should have shape (N_grid,). | |
| datastore : BaseRegularGridDatastore | |
| The datastore containing metadata about the grid. | |
| vmin : float, optional | |
| Minimum value for color scale. | |
| vmax : float, optional | |
| Maximum value for color scale. | |
| ax_title : str, optional | |
| Title for the axis. | |
| cmap : str or matplotlib.colors.Colormap, optional | |
| Colormap to use for plotting. | |
| boundary_alpha : float, optional | |
| If provided, overlay boundary mask with given alpha transparency. | |
| crop_to_interior : bool, optional | |
| If True, crop the plot to the interior region. | |
| Returns | |
| ------- | |
| matplotlib.collections.QuadMesh | |
| The mesh object created by pcolormesh. | |
| """ |
| color_values=None, | ||
| colorbar_label=None, | ||
| ): | ||
| """ | ||
| Plot a heatmap of errors of different variables at different | ||
| predictions horizons | ||
| errors: (pred_steps, d_f) | ||
| Plot a heatmap of errors for state variables across forecast lead times. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| errors : torch.Tensor | ||
| Error values with shape `(pred_steps, d_f)`. | ||
| datastore : BaseRegularGridDatastore | ||
| Datastore providing step length and variable metadata. | ||
| title : str, optional | ||
| Optional title for the figure. | ||
| color_values : torch.Tensor, optional | ||
| Optional values used only for the background colors. If omitted, | ||
| colors are normalized from ``errors`` using datastore state-variable | ||
| standardization statistics. | ||
| colorbar_label : str, optional | ||
| Optional label for the colorbar. If omitted, an automatic label is | ||
| chosen based on the color normalization used. | ||
| """ | ||
| errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) | ||
| errors_np = _to_heatmap_matrix(errors) | ||
| d_f, pred_steps = errors_np.shape | ||
| step_length = datastore.step_length | ||
|
|
||
| # Normalize all errors to [0,1] for color map | ||
| max_errors = errors_np.max(axis=1) # d_f | ||
| errors_norm = errors_np / np.expand_dims(max_errors, axis=1) | ||
|
|
||
| time_step_int, time_step_unit = utils.get_integer_time(step_length) | ||
| layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps) | ||
|
|
||
| fig, ax = plt.subplots(figsize=(15, 10)) | ||
| if color_values is None: | ||
| color_values_np, default_colorbar_label = _get_heatmap_color_values( | ||
| errors_np, datastore | ||
| ) | ||
| else: | ||
| color_values_np = _to_heatmap_matrix(color_values) | ||
| default_colorbar_label = "Relative scale" | ||
| if color_values_np.shape != errors_np.shape: | ||
| raise ValueError( | ||
| "color_values must have the same shape as errors: " | ||
| f"got {color_values_np.T.shape} and {errors_np.T.shape}" | ||
| ) | ||
|
|
||
| ax.imshow( | ||
| errors_norm, | ||
| cmap="OrRd", | ||
| vmin=0, | ||
| vmax=1.0, | ||
| if colorbar_label is None: | ||
| colorbar_label = default_colorbar_label |
color_values and colorbar_label are never used at the only call site (ar_model.py) and add ~20 lines for zero current benefit. Three small removals needed:
1. Remove the two params from the signature (lines 264–265):
# delete these two lines
color_values=None,
colorbar_label=None,

2. Remove their docstring entries (lines 278–284):
# delete these seven lines
color_values : torch.Tensor, optional
Optional values used only for the background colors. If omitted,
colors are normalized from ``errors`` using datastore state-variable
standardization statistics.
colorbar_label : str, optional
Optional label for the colorbar. If omitted, an automatic label is
chosen based on the color normalization used.

3. Replace the if/else block (lines 293–307) with the direct call:
color_values_np, colorbar_label = _get_heatmap_color_values(
errors_np, datastore
)
|
@sadamov Thanks for the review, I will fix these and get back to you ASAP :) |
|
Make sure to also run pre-commits (and pytest) locally before the next commit 🙏 |
|
I reran the plotting tests locally on the latest branch state before pushing, and |
sadamov
left a comment
Looking good, just a few minor things to fix before merge
| if len(var_names) < n_vars: | ||
| var_names.extend( | ||
| [f"state_feature_{i}" for i in range(len(var_names), n_vars)] | ||
| ) | ||
| if len(var_units) < n_vars: | ||
| var_units.extend([""] * (n_vars - len(var_units))) |
n_vars always equals errors.shape[0] (from the same datastore that provides var names), so the padding block can never trigger. Simplify to:
| if len(var_names) < n_vars: | |
| var_names.extend( | |
| [f"state_feature_{i}" for i in range(len(var_names), n_vars)] | |
| ) | |
| if len(var_units) < n_vars: | |
| var_units.extend([""] * (n_vars - len(var_units))) | |
| var_names = datastore.get_vars_names(category="state") | |
| var_units = datastore.get_vars_units(category="state") | |
| return [ | |
| _tex_safe(f"{n} ({u})" if u else n) | |
| for n, u in zip(var_names, var_units) | |
| ] |
| @matplotlib.rc_context(utils.fractional_plot_bundle(1)) | ||
| def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): | ||
| def plot_error_heatmap( |
The figure size set here is immediately overridden by figsize= below, so this looks like dead code. The context applies NeurIPS font/text settings - add a one-liner to make that clear:
| @matplotlib.rc_context(utils.fractional_plot_bundle(1)) | |
| def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): | |
| def plot_error_heatmap( | |
| # rc_context applies NeurIPS font/text settings; figure size is overridden | |
| # by the explicit figsize= below but font family and usetex stay in effect. | |
| @matplotlib.rc_context(utils.fractional_plot_bundle(1)) | |
| def plot_error_heatmap( |
| def test_plot_error_heatmap_adapts_figure_and_font_sizes(): | ||
| """Dense heatmaps should get more space and smaller text.""" | ||
| small_errors = torch.ones((4, 5)) | ||
| large_errors = torch.ones((20, 30)) | ||
|
|
||
| small_fig = vis.plot_error_heatmap( | ||
| small_errors, datastore=HeatmapDatastore(n_vars=small_errors.shape[1]) | ||
| ) | ||
| large_fig = vis.plot_error_heatmap( | ||
| large_errors, datastore=HeatmapDatastore(n_vars=large_errors.shape[1]) | ||
| ) | ||
|
|
||
| small_ax = small_fig.axes[0] | ||
| large_ax = large_fig.axes[0] | ||
|
|
||
| assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0] | ||
| assert large_fig.get_size_inches()[1] > small_fig.get_size_inches()[1] | ||
| assert ( | ||
| large_ax.get_yticklabels()[0].get_fontsize() | ||
| < small_ax.get_yticklabels()[0].get_fontsize() | ||
| ) | ||
| assert large_ax.texts[0].get_fontsize() < small_ax.texts[0].get_fontsize() | ||
| assert large_ax.get_xticklabels()[0].get_rotation() == 45.0 | ||
|
|
||
| plt.close(small_fig) | ||
| plt.close(large_fig) | ||
|
|
||
|
|
||
| def test_plot_error_heatmap_skips_annotations_for_very_dense_grids(): |
test_plot_error_heatmap_adapts_figure_and_font_sizes and test_plot_error_heatmap_skips_annotations_for_very_dense_grids share the same setup and exercise the same _compute_heatmap_layout path. Merge into one:
| def test_plot_error_heatmap_adapts_figure_and_font_sizes(): | |
| """Dense heatmaps should get more space and smaller text.""" | |
| small_errors = torch.ones((4, 5)) | |
| large_errors = torch.ones((20, 30)) | |
| small_fig = vis.plot_error_heatmap( | |
| small_errors, datastore=HeatmapDatastore(n_vars=small_errors.shape[1]) | |
| ) | |
| large_fig = vis.plot_error_heatmap( | |
| large_errors, datastore=HeatmapDatastore(n_vars=large_errors.shape[1]) | |
| ) | |
| small_ax = small_fig.axes[0] | |
| large_ax = large_fig.axes[0] | |
| assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0] | |
| assert large_fig.get_size_inches()[1] > small_fig.get_size_inches()[1] | |
| assert ( | |
| large_ax.get_yticklabels()[0].get_fontsize() | |
| < small_ax.get_yticklabels()[0].get_fontsize() | |
| ) | |
| assert large_ax.texts[0].get_fontsize() < small_ax.texts[0].get_fontsize() | |
| assert large_ax.get_xticklabels()[0].get_rotation() == 45.0 | |
| plt.close(small_fig) | |
| plt.close(large_fig) | |
| def test_plot_error_heatmap_skips_annotations_for_very_dense_grids(): | |
| def test_plot_error_heatmap_adapts_layout_for_grid_size(): | |
| """Dense heatmaps get larger figures, smaller fonts, and suppress annotations.""" | |
| small_fig = vis.plot_error_heatmap(torch.ones((4, 5)), datastore=HeatmapDatastore(n_vars=5)) | |
| large_fig = vis.plot_error_heatmap(torch.ones((20, 30)), datastore=HeatmapDatastore(n_vars=30)) | |
| dense_fig = vis.plot_error_heatmap(torch.ones((40, 50)), datastore=HeatmapDatastore(n_vars=50)) | |
| assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0] | |
| assert large_fig.axes[0].get_yticklabels()[0].get_fontsize() < small_fig.axes[0].get_yticklabels()[0].get_fontsize() | |
| assert large_fig.axes[0].get_xticklabels()[0].get_rotation() == 45.0 | |
| assert len(dense_fig.axes[0].texts) == 0 | |
| assert dense_fig.get_size_inches()[0] > 18.0 | |
| plt.close(small_fig); plt.close(large_fig); plt.close(dense_fig) |
|
Hey @sadamov, I addressed the latest review changes; ready for review. |
|
Also ran all tests locally before |
sadamov
left a comment
I really like these new maps! They will prove very useful. This PR is proposed for v0.7.0, will be discussed at the next dev meeting, and is marked as ready. @kshirajahere thanks for working on this over the past weeks!
|
@kshirajahere I saw you added the docstrings asked for during the dev meeting. Great! Now we wait for #208 |
…parameter

Adds `normalization: str = "state_std"` to `plot_error_heatmap` and rewrites `_get_heatmap_color_values` to support two explicit modes:

- `"state_std"`: RMSE / state_std, white-to-red Reds colormap
- `"one_step"`: RMSE / diff_std, white-to-red Reds colormap

Both modes fall back to per-variable max (also Reds, [0,1]) when their required stat is missing. The two modes never silently upgrade to each other. Removes the custom `_HEATMAP_CMAP` in favour of the built-in `"Reds"` colormap string. Updates tests accordingly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
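A condensed sketch of that two-mode behaviour (function and argument names simplified here; the actual _get_heatmap_color_values looks the stats up from the datastore):

import numpy as np

def heatmap_color_values_sketch(errors_np, state_std, diff_std, normalization="state_std"):
    """Simplified sketch: pick colour values and a colorbar label.

    errors_np : (d_f, pred_steps) physical-unit errors
    state_std, diff_std : per-variable stds (shape (d_f,)) or None if unavailable
    """
    if normalization == "state_std" and state_std is not None:
        return errors_np / state_std[:, None], "RMSE / state_std"
    if normalization == "one_step" and diff_std is not None:
        return errors_np / diff_std[:, None], "RMSE / diff_std"
    # Fallback: per-variable max in [0, 1]; the two modes never swap silently.
    return errors_np / errors_np.max(axis=1, keepdims=True), "Relative scale (per-variable max)"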
|
Hi @kshirajahere,

After the dev meeting I took another look at some realistic heatmaps using the COSMO checkpoint from the boundary paper. And as @joeloskarsson mentioned, the heatmaps really are dominated by a few variables. Also, I think the three-level silent fallback can be a bit dangerous, as the user might not get the heatmap they expected. Because we already completed the PR review I implemented the changes here for you, so as not to cause you more work. But I would appreciate your review, and if you agree, merge with #376.

Here are four plots for four different colormaps based on different scalings/normalizations (the skill one is not part of the PR, because we said we want only red tones 🔴):

cosmo_skill_heatmap.pdf (omitted from PR)

In this PR we had the one-step diff as the default. But as you can see, this is basically dominated by the pressure variables, which change only very slightly in one forecast step. Sea-level pressure is actually a very important variable for assessing model skill, also in traditional NWP models. It contains information about the whole atmospheric column and therefore about the layering and horizontal gradients -> wind. If you agree with my logic here and with the PR into your branch, feel free to merge. Or let me know if you see some other reason why my default is bad. kshirajahere#1 |
|
Hey @sadamov, I went through the stacked PR. It looks good from what I have reviewed (just a few questions before merging; I have raised them on kshirajahere#1). This feels like a cleaner match to the dev-meeting discussion and should make the heatmap behavior easier to reason about from the API and the colorbar label. Thanks a lot ❤️ |
Some suggestions for the colormap scaling
|
Merged kshirajahere#1 into this PR. Thanks @sadamov :D |







Describe your changes
This PR improves the metric heatmap plotting used during evaluation.
Summary of changes
- plot_error_map renamed to plot_error_heatmap
- ARModel call site and related wording updated to use “heatmap” consistently
- CHANGELOG: unreleased -> Fixed

Motivation / context
Fixes the three issues raised in #375:
Dependencies
No new external dependencies.
Issue Link
Closes #375
Type of change
Checklist before requesting a review
pull with --rebase option if possible).

Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee
Tests
Ran locally:
- python -m pytest -q tests/test_plotting.py -k "heatmap" -> passed
- python -m py_compile neural_lam/vis.py neural_lam/models/ar_model.py tests/test_plotting.py -> passed

Notes:
python -m pytest -q tests/test_plotting.py hits an existing cartopy/Natural Earth data path permission issue on this machine when the older integration plotting tests try to access coastline data. The new focused heatmap tests pass.