diff --git a/CHANGELOG.md b/CHANGELOG.md index 553475ef..3cd17073 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- Change metric heatmap (`plot_error_map`, now `plot_error_heatmap`) to use a + shared cross-variable color scale instead of per-row normalization, add a + colorbar, and scale figure size and font sizes with grid dimensions + ([#375](https://github.com/mllam/neural-lam/issues/375)) + - Resolve `xarray` `FacetGrid` `DeprecationWarning` in `plot_example.py` by using a compatibility shim [\#482](https://github.com/mllam/neural-lam/pull/482) @sohampatil01-svg ## [v0.6.0](https://github.com/mllam/neural-lam/releases/tag/v0.6.0) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 80356c60..6f17f638 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -383,7 +383,7 @@ def on_validation_epoch_end(self): """ Compute val metrics at the end of val epoch """ - # Create error maps for all test metrics + # Create error heatmaps for all validation metrics self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") if self.trainer.is_global_zero: @@ -436,9 +436,10 @@ def test_step(self, batch, batch_idx): batch_size=batch[0].shape[0], ) - # Compute all evaluation metrics for error maps Note: explicitly list - # metrics here, as test_metrics can contain additional ones, computed - # differently, but that should be aggregated on_test_epoch_end + # Compute all evaluation metrics for error heatmaps. Note: + # explicitly list metrics here, as test_metrics can contain + # additional ones, computed differently, but that should be + # aggregated on_test_epoch_end for metric_name in ("mse", "mae"): metric_func = metrics.get_metric(metric_name) batch_metric_vals = metric_func( @@ -669,7 +670,7 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): Return: log_dict: dict with everything to log for given metric """ log_dict = {} - metric_fig = vis.plot_error_map( + metric_fig = vis.plot_error_heatmap( errors=metric_tensor, datastore=self._datastore, ) @@ -702,7 +703,8 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): def aggregate_and_plot_metrics(self, metrics_dict, prefix): """ - Aggregate and create error map plots for all metrics in metrics_dict + Aggregate and create error heatmap plots for all metrics in + metrics_dict metrics_dict: dictionary with metric_names and list of tensors with step-evals. @@ -760,7 +762,7 @@ def on_test_epoch_end(self): Compute test metrics and make plots at the end of test epoch. Will gather stored tensors and perform plotting and logging on rank 0. """ - # Create error maps for all test metrics + # Create error heatmaps for all test metrics self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") # Plot spatial loss maps diff --git a/neural_lam/vis.py b/neural_lam/vis.py index e0d00815..79d71064 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,3 +1,6 @@ +# Standard library +import warnings + # Third-party import cartopy.crs as ccrs import cartopy.feature as cfeature @@ -12,11 +15,20 @@ from . import utils from .datastore.base import BaseRegularGridDatastore -# Font sizes shared across all plot functions for visual consistency. +# Font sizes shared across projection-aware plot functions. _TITLE_SIZE = 13 # suptitle and per-axes titles _LABEL_SIZE = 11 # axis / colorbar labels _TICK_SIZE = 11 # tick labels +# Annotations become unreadable when cells are smaller than this (in points) +# or when the total number of cells exceeds a readable count. +_MIN_CELL_SIZE_FOR_ANNOTATIONS = 18 +_MAX_CELLS_FOR_ANNOTATIONS = 800 +_HEATMAP_CMAP = matplotlib.colors.LinearSegmentedColormap.from_list( + "error_heatmap_white_red", + ["#ffffff", "#fee5d9", "#fcae91", "#fb6a4a", "#cb181d"], +) + def _tex_safe(s: str) -> str: """Escape TeX special characters in s if TeX rendering is currently active. @@ -29,6 +41,188 @@ def _tex_safe(s: str) -> str: return s +def _compute_heatmap_layout(n_rows: int, n_cols: int) -> dict[str, float]: + """Choose figure and font sizes from the heatmap dimensions.""" + max_dim = max(n_rows, n_cols) + + # Size the figure so each cell gets ~0.8 x 0.5 inches; floor at 8 x 4.5. + fig_width = float(max(4.5 + 0.8 * n_cols, 8.0)) + fig_height = float(max(2.5 + 0.5 * n_rows, 4.5)) + + # Approximate cell size in points (72 pt/inch) to decide whether + # in-cell annotations will be legible. + cell_w_pt = (fig_width / max(n_cols, 1)) * 72 + cell_h_pt = (fig_height / max(n_rows, 1)) * 72 + show_annotations = ( + n_rows * n_cols <= _MAX_CELLS_FOR_ANNOTATIONS + and min(cell_w_pt, cell_h_pt) >= _MIN_CELL_SIZE_FOR_ANNOTATIONS + ) + + return { + "fig_width": fig_width, + "fig_height": fig_height, + "tick_label_size": float(np.clip(15.0 - 0.18 * max_dim, 9.0, 14.0)), + "annotation_size": float(np.clip(13.0 - 0.22 * max_dim, 5.0, 12.0)), + "title_size": float(np.clip(16.0 - 0.15 * max_dim, 9.0, 15.0)), + "x_tick_rotation": 45.0 if n_cols > 12 else 0.0, + "show_annotations": show_annotations, + } + + +def _get_heatmap_var_labels(datastore: BaseRegularGridDatastore) -> list[str]: + """Build state-variable labels from datastore metadata.""" + var_names = datastore.get_vars_names(category="state") + var_units = datastore.get_vars_units(category="state") + return [ + _tex_safe(f"{name} ({unit})" if unit else name) + for name, unit in zip(var_names, var_units) + ] + + +def _to_heatmap_matrix(values) -> np.ndarray: + """ + Convert heatmap inputs to a `(d_f, pred_steps)` matrix. + + A single-step tensor may arrive as one-dimensional `(d_f,)`, especially in + single-GPU or focused metric logging paths. In that case we first treat it + as one row of `(pred_steps=1, d_f)` before transposing. + """ + if hasattr(values, "detach"): + values = values.detach().cpu().numpy() + values = np.asarray(values, dtype=float) + if values.ndim == 1: + values = values[np.newaxis, :] + return values.T + + +def _get_feature_scale( + ds_stats: xr.Dataset, var_name: str, n_vars: int +) -> np.ndarray | None: + """Extract a 1D per-feature scale, averaging over any extra dims.""" + if var_name not in ds_stats: + return None + + da_scale = ds_stats[var_name] + feature_dim = "state_feature" + if feature_dim not in da_scale.dims: + return None + + reduce_dims = [dim for dim in da_scale.dims if dim != feature_dim] + if reduce_dims: + da_scale = da_scale.mean(dim=reduce_dims) + + scale = np.asarray(da_scale.values, dtype=float).reshape(-1) + if scale.size < n_vars: + return None + + return scale[:n_vars] + + +def _get_heatmap_color_values( + errors_np: np.ndarray, + datastore: BaseRegularGridDatastore, + normalization: str, +) -> tuple[np.ndarray, str, matplotlib.colors.Colormap]: + """ + Normalize heatmap colors according to `normalization`. + + Returns a 3-tuple: (color_values, colorbar_label, cmap). + The returned array drives the colormap only; the numeric annotations in the + heatmap remain the original (physical-unit) values passed in `errors_np`. + + Both modes fall back to per-variable max normalization when their required + stat is unavailable, appending "[fallback]" to the colorbar label. + """ + + def _per_var_fallback(): + max_err = errors_np.max(axis=1, keepdims=True) + safe = np.where(max_err > np.finfo(float).eps, max_err, 1.0) + return ( + errors_np / safe, + "Per-variable scale (relative to max error) [fallback]", + _HEATMAP_CMAP, + ) + + try: + ds_stats = datastore.get_standardization_dataarray(category="state") + except (AttributeError, KeyError, TypeError, ValueError) as exc: + warnings.warn( + f"Could not load standardization stats ({exc}); " + "falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + + n_vars = errors_np.shape[0] + + if normalization == "state_std": + state_std = _get_feature_scale(ds_stats, "state_std", n_vars) + if state_std is None: + warnings.warn( + "state_std unavailable; falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + safe_std = np.where( + np.isfinite(state_std) & (state_std > np.finfo(float).eps), + state_std, + 1.0, + ) + return errors_np / safe_std[:, None], "Error / state_std", _HEATMAP_CMAP + + if normalization == "diff_std": + diff_std_std = _get_feature_scale( + ds_stats, "state_diff_std_standardized", n_vars + ) + if diff_std_std is None: + warnings.warn( + "state_diff_std_standardized unavailable; " + "falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + state_std = _get_feature_scale(ds_stats, "state_std", n_vars) + if state_std is None: + warnings.warn( + "state_std unavailable (needed to recover physical diff_std); " + "falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + scale = state_std * diff_std_std # physical diff_std + safe = np.where( + np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps), + scale, + 1.0, + ) + return ( + errors_np / safe[:, None], + "Error / physical diff_std", + _HEATMAP_CMAP, + ) + + raise ValueError( + f"Unknown normalization {normalization!r}; " + "expected 'state_std' or 'diff_std'." + ) + + +def _get_annotation_text_color( + value: float, image: matplotlib.image.AxesImage +) -> str: + """Choose a readable annotation color from the rendered background.""" + if not np.isfinite(value): + return "black" + + rgba = image.cmap(image.norm(value)) + luminance = 0.2126 * rgba[0] + 0.7152 * rgba[1] + 0.0722 * rgba[2] + return "white" if luminance < 0.5 else "black" + + def plot_on_axis( ax, da, @@ -40,7 +234,7 @@ def plot_on_axis( boundary_alpha=None, crop_to_interior=False, ): - """Plot weather state on given axis using datastore metadata. + """Plot weather state on a projection-aware axis using datastore metadata. Parameters ---------- @@ -67,9 +261,7 @@ def plot_on_axis( ------- matplotlib.collections.QuadMesh The mesh object created by pcolormesh. - """ - ax.coastlines(resolution="50m") ax.add_feature(cfeature.BORDERS, linestyle="-", alpha=0.5) @@ -154,68 +346,130 @@ def plot_on_axis( return mesh +# rc_context applies NeurIPS font/text settings; figure size is overridden +# by the explicit figsize= below but font family and usetex stay in effect. @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): - """ - Plot a heatmap of errors of different variables at different - predictions horizons - errors: (pred_steps, d_f) +def plot_error_heatmap( + errors, + datastore: BaseRegularGridDatastore, + title=None, + normalization: str = "state_std", +): """ + Plot a heatmap of errors for state variables across forecast lead times. - # Ensure errors is 2D even for single-step/single-GPU runs - if errors.dim() == 1: - errors = errors.unsqueeze(0) - - errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) + Parameters + ---------- + errors : torch.Tensor + Error values with shape `(pred_steps, d_f)`. These values are used for + the numeric annotations in each cell. + datastore : BaseRegularGridDatastore + Datastore providing step length and variable metadata. + title : str, optional + Optional title for the figure. + normalization : {"state_std", "diff_std"}, default "state_std" + Color scaling mode. "state_std" divides by climatological std; + "diff_std" divides by the typical one-step change magnitude. Both fall + back to per-variable max error when the required stats are missing. + + Notes + ----- + Color scaling is controlled by `normalization`; see + `_get_heatmap_color_values` for the full fallback logic. When stats are + unavailable the colorbar label includes "[fallback]". + """ + errors_np = _to_heatmap_matrix(errors) d_f, pred_steps = errors_np.shape step_length = datastore.step_length - # Normalize all errors to [0,1] for color map - max_errors = errors_np.max(axis=1) # d_f - errors_norm = errors_np / np.expand_dims(max_errors, axis=1) - time_step_int, time_step_unit = utils.get_integer_time(step_length) + layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps) + color_values_np, colorbar_label, heatmap_cmap = _get_heatmap_color_values( + errors_np, datastore, normalization + ) + finite_color_values = color_values_np[np.isfinite(color_values_np)] + if finite_color_values.size == 0: + vmin, vmax = 0.0, 1.0 + else: + vmin = float(finite_color_values.min()) + vmax = float(finite_color_values.max()) + vmin = min(0.0, vmin) + if np.isclose(vmin, vmax): + vmax = vmin + 1.0 - fig, ax = plt.subplots(figsize=(15, 10)) + fig, ax = plt.subplots( + figsize=(layout["fig_width"], layout["fig_height"]), + constrained_layout=True, + ) - ax.imshow( - errors_norm, - cmap="OrRd", - vmin=0, - vmax=1.0, + im = ax.imshow( + color_values_np, + cmap=heatmap_cmap, + vmin=vmin, + vmax=vmax, interpolation="none", aspect="auto", - alpha=0.8, ) + cbar = fig.colorbar(im, ax=ax, pad=0.02) + cbar.set_label(_tex_safe(colorbar_label), size=layout["tick_label_size"]) + cbar.ax.tick_params(labelsize=layout["tick_label_size"]) + cbar.ax.yaxis.get_offset_text().set_fontsize(layout["tick_label_size"]) + + if layout["show_annotations"]: + for (j, i), error in np.ndenumerate(errors_np): + if np.isfinite(error): + formatted_error = ( + f"{error:.3g}" if abs(error) < 1.0e4 else f"{error:.2E}" + ) + else: + formatted_error = str(error) + text_color = _get_annotation_text_color(color_values_np[j, i], im) + ax.text( + i, + j, + formatted_error, + ha="center", + va="center", + usetex=False, + fontsize=layout["annotation_size"], + color=text_color, + ) - # ax and labels - for (j, i), error in np.ndenumerate(errors_np): - # Numbers > 9999 will be too large to fit - formatted_error = f"{error:.3f}" if error < 9999 else f"{error:.2E}" - ax.text(i, j, formatted_error, ha="center", va="center", usetex=False) - - # Ticks and labels ax.set_xticks(np.arange(pred_steps)) pred_hor_i = np.arange(pred_steps) + 1 pred_hor_h = time_step_int * pred_hor_i - ax.set_xticklabels(pred_hor_h, size=_TICK_SIZE) - ax.set_xlabel(f"Lead time ({time_step_unit[0]})", size=_LABEL_SIZE) + ax.set_xticklabels( + pred_hor_h, + size=layout["tick_label_size"], + rotation=layout["x_tick_rotation"], + ha="right" if layout["x_tick_rotation"] > 0 else "center", + ) + ax.set_xlabel( + f"Lead time ({time_step_unit[0]})", size=layout["tick_label_size"] + ) ax.set_yticks(np.arange(d_f)) - var_names = datastore.get_vars_names(category="state") - var_units = datastore.get_vars_units(category="state") - y_ticklabels = [ - _tex_safe(f"{name} ({unit})") - for name, unit in zip(var_names, var_units) - ] - ax.set_yticklabels(y_ticklabels, rotation=30, size=_TICK_SIZE) + ax.set_yticklabels( + _get_heatmap_var_labels(datastore=datastore), + size=layout["tick_label_size"], + ) if title: - ax.set_title(title, size=_TITLE_SIZE) + ax.set_title(title, size=layout["title_size"]) return fig +def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): + """Deprecated: use :func:`plot_error_heatmap` instead.""" + warnings.warn( + "plot_error_map is deprecated, use plot_error_heatmap instead", + DeprecationWarning, + stacklevel=2, + ) + return plot_error_heatmap(errors, datastore=datastore, title=title) + + @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( datastore: BaseRegularGridDatastore, diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ba03857c..dd8e5e20 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -5,6 +5,7 @@ from unittest.mock import patch # Third-party +import matplotlib import matplotlib.figure import matplotlib.pyplot as plt import numpy as np @@ -27,6 +28,12 @@ TEST_OUTPUT_DIR.mkdir(parents=True, exist_ok=True) +@pytest.fixture(scope="session", autouse=True) +def _set_agg_backend(): + """Use non-interactive backend for all plotting tests.""" + plt.switch_backend("Agg") + + @pytest.fixture(autouse=True) def mock_cartopy_downloads(monkeypatch: pytest.MonkeyPatch) -> None: """ @@ -47,6 +54,50 @@ def close_all_figures_after_test() -> Iterator[None]: plt.close("all") +class HeatmapDatastore: + """Minimal datastore stub for error-heatmap plotting tests.""" + + def __init__( + self, + n_vars, + step_length=timedelta(hours=1), + state_std=None, + state_diff_std_standardized=None, + ): + self._n_vars = n_vars + self.step_length = step_length + self._state_std = ( + np.asarray(state_std, dtype=float) + if state_std is not None + else np.ones(n_vars, dtype=float) + ) + self._state_diff_std_standardized = ( + np.asarray(state_diff_std_standardized, dtype=float) + if state_diff_std_standardized is not None + else np.ones(n_vars, dtype=float) + ) + + def get_vars_names(self, category): + assert category == "state" + return [f"state_var_{i}" for i in range(self._n_vars)] + + def get_vars_units(self, category): + assert category == "state" + return ["unit"] * self._n_vars + + def get_standardization_dataarray(self, category): + assert category == "state" + return xr.Dataset( + { + "state_std": (("state_feature",), self._state_std), + "state_diff_std_standardized": ( + ("state_feature",), + self._state_diff_std_standardized, + ), + } + ) + + def test_plot_prediction() -> None: """Check prediction plot structure, titles and shared color scaling.""" datastore = init_datastore_example("dummydata") @@ -86,7 +137,7 @@ def test_plot_prediction() -> None: def test_plot_error_map() -> None: - """Check error heatmap content, labels and annotations.""" + """Check the deprecated error-heatmap wrapper still renders correctly.""" datastore = init_datastore_example("dummydata") d_f = len(datastore.get_vars_names(category="state")) pred_steps = 4 @@ -95,16 +146,17 @@ def test_plot_error_map() -> None: pred_steps, d_f ) - fig = vis.plot_error_map( - errors=errors, - datastore=datastore, - title="Test Error Map", - ) + with pytest.warns(DeprecationWarning, match="plot_error_heatmap"): + fig = vis.plot_error_map( + errors=errors, + datastore=datastore, + title="Test Error Map", + ) assert isinstance(fig, matplotlib.figure.Figure) - assert len(fig.axes) == 1 + assert len(fig.axes) == 2 - ax = fig.axes[0] + ax, colorbar_ax = fig.axes assert len(ax.images) == 1 assert ax.images[0].get_array().shape == (d_f, pred_steps) assert ax.get_xlabel() == "Lead time (h)" @@ -123,6 +175,165 @@ def test_plot_error_map() -> None: assert actual_y_ticklabels == expected_y_ticklabels assert len(ax.texts) == pred_steps * d_f + assert colorbar_ax.get_ylabel() != "" + + +def test_plot_error_heatmap_state_std_normalization(): + """state_std mode: colors are Error / state_std.""" + errors = torch.tensor( + [ + [1.0, 100.0, 10.0], + [2.0, 80.0, 5.0], + [3.0, 60.0, 2.5], + ] + ) + datastore = HeatmapDatastore( + n_vars=errors.shape[1], + state_std=[1.0, 100.0, 10.0], + state_diff_std_standardized=[1.0, 2.0, 0.5], + ) + + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="state_std" + ) + ax = fig.axes[0] + image = ax.images[0] + colorbar = fig.axes[1] + + expected = errors.T.numpy() / np.array([[1.0], [100.0], [10.0]]) + np.testing.assert_allclose(image.get_array(), expected) + assert colorbar.get_ylabel() == "Error / state_std" + + plt.close(fig) + + +def test_plot_error_heatmap_diff_std_normalization(): + """diff_std mode: colors are Error / physical diff_std.""" + errors = torch.tensor( + [ + [1.0, 100.0, 10.0], + [2.0, 80.0, 5.0], + [3.0, 60.0, 2.5], + ] + ) + datastore = HeatmapDatastore( + n_vars=errors.shape[1], + state_std=[1.0, 100.0, 10.0], + state_diff_std_standardized=[1.0, 2.0, 0.5], + ) + + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="diff_std" + ) + ax = fig.axes[0] + image = ax.images[0] + colorbar = fig.axes[1] + + # physical diff_std = state_std * state_diff_std_standardized = [1, 200, 5] + expected = errors.T.numpy() / np.array([[1.0], [200.0], [5.0]]) + np.testing.assert_allclose(image.get_array(), expected) + assert colorbar.get_ylabel() == "Error / physical diff_std" + + plt.close(fig) + + +def test_plot_error_heatmap_falls_back_to_per_var_scale_without_stats(): + """Colors fall back to per-variable max when stats are unavailable.""" + + class NoStatsHeatmapDatastore(HeatmapDatastore): + def get_standardization_dataarray(self, category): + raise KeyError("Missing standardization stats") + + errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + datastore = NoStatsHeatmapDatastore(n_vars=errors.shape[1]) + + with pytest.warns(UserWarning, match="falling back to per-variable scale"): + fig = vis.plot_error_heatmap(errors, datastore=datastore) + ax = fig.axes[0] + image = ax.images[0] + colorbar = fig.axes[1] + + # errors_np after transpose: var0=[1,3], var1=[2,4]; max per var: [3,4] + expected = np.array([[1 / 3, 3 / 3], [2 / 4, 4 / 4]]) + np.testing.assert_allclose(image.get_array(), expected) + assert "[fallback]" in colorbar.get_ylabel() + + plt.close(fig) + + +def test_plot_error_heatmap_diff_std_falls_back_when_diff_std_absent(): + """diff_std mode falls back to per-var max when diff_std is missing.""" + + class StateStdOnlyDatastore(HeatmapDatastore): + def get_standardization_dataarray(self, category): + return ( + super() + .get_standardization_dataarray(category) + .drop_vars("state_diff_std_standardized") + ) + + errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]) + + with pytest.warns(UserWarning, match="falling back to per-variable scale"): + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="diff_std" + ) + colorbar = fig.axes[1] + assert "[fallback]" in colorbar.get_ylabel() + plt.close(fig) + + +def test_plot_error_heatmap_state_std_ignores_diff_std(): + """state_std mode is unaffected by presence or absence of diff_std.""" + + class StateStdOnlyDatastore(HeatmapDatastore): + def get_standardization_dataarray(self, category): + return ( + super() + .get_standardization_dataarray(category) + .drop_vars("state_diff_std_standardized") + ) + + errors = torch.tensor([[2.0, 4.0], [1.0, 3.0]]) + datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]) + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="state_std" + ) + ax = fig.axes[0] + image = ax.images[0] + + # errors_np after transpose: var0=[2,1], var1=[4,3]; state_std=[2,1] + expected = np.array([[2 / 2, 1 / 2], [4 / 1, 3 / 1]]) + np.testing.assert_allclose(image.get_array(), expected) + assert fig.axes[1].get_ylabel() == "Error / state_std" + plt.close(fig) + + +def test_plot_error_heatmap_adapts_layout_for_grid_size(): + """Dense heatmaps adapt size, font scale, and annotation density.""" + small_fig = vis.plot_error_heatmap( + torch.ones((4, 5)), datastore=HeatmapDatastore(n_vars=5) + ) + large_fig = vis.plot_error_heatmap( + torch.ones((20, 30)), datastore=HeatmapDatastore(n_vars=30) + ) + dense_fig = vis.plot_error_heatmap( + torch.ones((40, 50)), datastore=HeatmapDatastore(n_vars=50) + ) + + assert large_fig.get_size_inches()[0] > small_fig.get_size_inches()[0] + assert ( + large_fig.axes[0].get_yticklabels()[0].get_fontsize() + < small_fig.axes[0].get_yticklabels()[0].get_fontsize() + ) + assert large_fig.axes[0].get_xticklabels()[0].get_rotation() == 45.0 + assert len(dense_fig.axes[0].texts) == 0 + assert dense_fig.get_size_inches()[0] > 18.0 + + plt.close(small_fig) + plt.close(large_fig) + plt.close(dense_fig) def test_plot_spatial_error() -> None: @@ -141,14 +352,12 @@ def test_plot_spatial_error() -> None: ) assert isinstance(fig, matplotlib.figure.Figure) - # GeoAxes + colorbar axes assert len(fig.axes) == 2 assert fig.texts[0].get_text() == "Test Spatial Error" def test_plot_spatial_error_crop_to_interior_changes_extent() -> None: - """Check interior cropping forwards interior lon/lat bounds to - set_extent.""" + """Check interior cropping forwards interior lon/lat bounds.""" datastore = init_datastore_example("dummydata") n_grid = datastore.num_grid_points grid_shape = (datastore.grid_shape_state.x, datastore.grid_shape_state.y) @@ -194,15 +403,15 @@ def test_plot_spatial_error_crop_to_interior_changes_extent() -> None: @pytest.fixture def model_and_batch(tmp_path, time_step, time_unit): - """Setup a model and dataset for testing plot_examples""" - # Create timedelta with specified step length + """Setup a model and dataset for testing plot_examples.""" + # Create timedelta with specified step length. step_length_kwargs = {time_unit: time_step} step_length = timedelta(**step_length_kwargs) - # Create datastore with specified step_length + # Create datastore with specified step length. datastore = DummyDatastore(step_length=step_length) - # Create minimal model args + # Create minimal model args. class ModelArgs: output_std = False loss = "mse" @@ -221,6 +430,7 @@ class ModelArgs: num_future_forcing_steps = 0 var_leads_metrics_watch = {} + # Create graph files if they do not already exist. graph_dir_path = Path(datastore.root_path) / "graph" / "1level" if not graph_dir_path.exists(): create_graph_from_datastore( @@ -229,7 +439,7 @@ class ModelArgs: n_max_levels=1, ) - # Create config + # Create config. config = nlconfig.NeuralLAMConfig( datastore=nlconfig.DatastoreSelection( kind=datastore.SHORT_NAME, @@ -237,14 +447,14 @@ class ModelArgs: ), ) - # Create model + # Create model. model = GraphLAM( args=ModelArgs(), config=config, datastore=datastore, ) - # Create dataset to get a sample batch + # Create dataset to get a sample batch. dataset = WeatherDataset( datastore=datastore, split="train", @@ -253,9 +463,9 @@ class ModelArgs: num_future_forcing_steps=0, ) - # Get a batch (just use one sample) + # Get a batch (just use one sample). sample = dataset[0] - batch = tuple(torch.stack([item]) for item in sample) # Add batch dimension + batch = tuple(torch.stack([item]) for item in sample) return model, batch, datastore, tmp_path @@ -273,13 +483,13 @@ class ModelArgs: def test_plot_examples_integration_saves_figure( model_and_batch, time_step, time_unit, t_i ): - """Integration test that saves actual figure for manual inspection""" + """Integration test that saves an actual figure for manual inspection.""" model, batch, datastore, tmp_path = model_and_batch - # Reset plotted examples counter + # Reset plotted examples counter. model.plotted_examples = 0 - # Verify that the model correctly inferred time step from datastore + # Verify that the model correctly inferred time step from the datastore. assert ( model.time_step_int == time_step ), f"Expected time_step_int={time_step}, got {model.time_step_int}" @@ -287,19 +497,19 @@ def test_plot_examples_integration_saves_figure( model.time_step_unit == time_unit ), f"Expected time_step_unit={time_unit}, got {model.time_step_unit}" - # Generate prediction + # Generate prediction. prediction, target, _, _ = model.common_step(batch) - # Rescale to original data scale + # Rescale to original data scale. prediction_rescaled = prediction * model.state_std + model.state_mean target_rescaled = target * model.state_std + model.state_mean - # Get first example - pred_slice = prediction_rescaled[0].detach() # Detach from graph + # Get first example. + pred_slice = prediction_rescaled[0].detach() target_slice = target_rescaled[0].detach() time_slice = batch[3][0] - # Create DataArrays + # Create DataArrays. dataset = WeatherDataset(datastore=datastore, split="train") time = np.array(time_slice.cpu(), dtype="datetime64[ns]") @@ -312,7 +522,7 @@ def test_plot_examples_integration_saves_figure( tensor=target_slice, time=time, category="state" ).unstack("grid_index") - # Get vranges + # Get vranges. var_vmin = ( torch.minimum( pred_slice.flatten(0, 1).min(dim=0)[0], @@ -331,10 +541,10 @@ def test_plot_examples_integration_saves_figure( ) var_vranges = list(zip(var_vmin, var_vmax)) - # Create plot for specified timestep and first variable var_names = datastore.get_vars_names("state") var_units = datastore.get_vars_units("state") + # Create plot for specified timestep and first variable. fig = vis.plot_prediction( datastore=datastore, title=f"{var_names[0]}, t={t_i} ({time_step * t_i} {time_unit})", @@ -346,7 +556,7 @@ def test_plot_examples_integration_saves_figure( da_target=da_target.isel(state_feature=0, time=t_i - 1).squeeze(), ) - # Save for inspection + # Save for inspection. output_path = ( TEST_OUTPUT_DIR / f"ar_model_integration_t{t_i}_{time_step}{time_unit}.png" @@ -356,7 +566,7 @@ def test_plot_examples_integration_saves_figure( plt.close(fig) - # Verify the figure was created + # Verify the figure was created. assert fig is not None assert isinstance(fig, plt.Figure) assert output_path.exists()