diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 3e9c012d..79d71064 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -119,62 +119,96 @@ def _get_feature_scale( def _get_heatmap_color_values( - errors_np: np.ndarray, datastore: BaseRegularGridDatastore -) -> tuple[np.ndarray, str]: + errors_np: np.ndarray, + datastore: BaseRegularGridDatastore, + normalization: str, +) -> tuple[np.ndarray, str, matplotlib.colors.Colormap]: """ - Normalize heatmap colors to a cross-variable relative scale. + Normalize heatmap colors according to `normalization`. + Returns a 3-tuple: (color_values, colorbar_label, cmap). The returned array drives the colormap only; the numeric annotations in the heatmap remain the original (physical-unit) values passed in `errors_np`. - Scaling logic: - - Prefer a relative scale based on datastore standardization stats. - Start with `state_std` (per-variable climatological std). - - If `state_diff_std_standardized` is available, also fold that in to - represent error relative to typical one-step variability. - - If any required stats are missing or invalid, fall back to absolute - scaling using the raw error values. + Both modes fall back to per-variable max normalization when their required + stat is unavailable, appending "[fallback]" to the colorbar label. """ - try: - ds_state_stats = datastore.get_standardization_dataarray( - category="state" + + def _per_var_fallback(): + max_err = errors_np.max(axis=1, keepdims=True) + safe = np.where(max_err > np.finfo(float).eps, max_err, 1.0) + return ( + errors_np / safe, + "Per-variable scale (relative to max error) [fallback]", + _HEATMAP_CMAP, ) + + try: + ds_stats = datastore.get_standardization_dataarray(category="state") except (AttributeError, KeyError, TypeError, ValueError) as exc: warnings.warn( f"Could not load standardization stats ({exc}); " - "falling back to absolute scale.", + "falling back to per-variable scale.", UserWarning, stacklevel=3, ) - return errors_np, "Absolute scale" + return _per_var_fallback() n_vars = errors_np.shape[0] - state_std = _get_feature_scale(ds_state_stats, "state_std", n_vars) - if state_std is None: - warnings.warn( - "State standardization stats are unavailable; " - "falling back to absolute scale.", - UserWarning, - stacklevel=3, + + if normalization == "state_std": + state_std = _get_feature_scale(ds_stats, "state_std", n_vars) + if state_std is None: + warnings.warn( + "state_std unavailable; falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + safe_std = np.where( + np.isfinite(state_std) & (state_std > np.finfo(float).eps), + state_std, + 1.0, ) - return errors_np, "Absolute scale" + return errors_np / safe_std[:, None], "Error / state_std", _HEATMAP_CMAP - scale = state_std - colorbar_label = "Relative scale (state stds)" + if normalization == "diff_std": + diff_std_std = _get_feature_scale( + ds_stats, "state_diff_std_standardized", n_vars + ) + if diff_std_std is None: + warnings.warn( + "state_diff_std_standardized unavailable; " + "falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + state_std = _get_feature_scale(ds_stats, "state_std", n_vars) + if state_std is None: + warnings.warn( + "state_std unavailable (needed to recover physical diff_std); " + "falling back to per-variable scale.", + UserWarning, + stacklevel=3, + ) + return _per_var_fallback() + scale = state_std * diff_std_std # physical diff_std + safe = np.where( + np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps), + scale, + 1.0, + ) + return ( + errors_np / safe[:, None], + "Error / physical diff_std", + _HEATMAP_CMAP, + ) - state_diff_std_standardized = _get_feature_scale( - ds_state_stats, "state_diff_std_standardized", n_vars + raise ValueError( + f"Unknown normalization {normalization!r}; " + "expected 'state_std' or 'diff_std'." ) - if state_diff_std_standardized is not None: - scale = scale * state_diff_std_standardized - colorbar_label = "Error / Std(1-step change)" - - safe_scale = np.where( - np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps), - scale, - 1.0, - ) - return errors_np / safe_scale[:, None], colorbar_label def _get_annotation_text_color( @@ -319,6 +353,7 @@ def plot_error_heatmap( errors, datastore: BaseRegularGridDatastore, title=None, + normalization: str = "state_std", ): """ Plot a heatmap of errors for state variables across forecast lead times. @@ -332,13 +367,16 @@ def plot_error_heatmap( Datastore providing step length and variable metadata. title : str, optional Optional title for the figure. + normalization : {"state_std", "diff_std"}, default "state_std" + Color scaling mode. "state_std" divides by climatological std; + "diff_std" divides by the typical one-step change magnitude. Both fall + back to per-variable max error when the required stats are missing. Notes ----- - The heatmap colormap is driven by a relative cross-variable scale derived - from datastore standardization stats (see `_get_heatmap_color_values`). - If those stats are unavailable, the plot falls back to absolute scaling - on the raw `errors` values. + Color scaling is controlled by `normalization`; see + `_get_heatmap_color_values` for the full fallback logic. When stats are + unavailable the colorbar label includes "[fallback]". """ errors_np = _to_heatmap_matrix(errors) d_f, pred_steps = errors_np.shape @@ -346,18 +384,16 @@ def plot_error_heatmap( time_step_int, time_step_unit = utils.get_integer_time(step_length) layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps) - color_values_np, colorbar_label = _get_heatmap_color_values( - errors_np, datastore + color_values_np, colorbar_label, heatmap_cmap = _get_heatmap_color_values( + errors_np, datastore, normalization ) - finite_color_values = color_values_np[np.isfinite(color_values_np)] if finite_color_values.size == 0: vmin, vmax = 0.0, 1.0 else: vmin = float(finite_color_values.min()) vmax = float(finite_color_values.max()) - if vmin >= 0.0: - vmin = 0.0 + vmin = min(0.0, vmin) if np.isclose(vmin, vmax): vmax = vmin + 1.0 @@ -368,7 +404,7 @@ def plot_error_heatmap( im = ax.imshow( color_values_np, - cmap=_HEATMAP_CMAP, + cmap=heatmap_cmap, vmin=vmin, vmax=vmax, interpolation="none", diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 2e7429b0..dd8e5e20 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -178,8 +178,8 @@ def test_plot_error_map() -> None: assert colorbar_ax.get_ylabel() != "" -def test_plot_error_heatmap_uses_relative_color_scale(): - """Heatmap colors should compare relative magnitudes across variables.""" +def test_plot_error_heatmap_state_std_normalization(): + """state_std mode: colors are Error / state_std.""" errors = torch.tensor( [ [1.0, 100.0, 10.0], @@ -193,23 +193,52 @@ def test_plot_error_heatmap_uses_relative_color_scale(): state_diff_std_standardized=[1.0, 2.0, 0.5], ) - fig = vis.plot_error_heatmap(errors, datastore=datastore) + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="state_std" + ) ax = fig.axes[0] image = ax.images[0] colorbar = fig.axes[1] - expected_color_values = errors.T.numpy() / np.array([[1.0], [200.0], [5.0]]) - np.testing.assert_allclose(image.get_array(), expected_color_values) - assert image.norm.vmin == 0.0 - assert image.norm.vmax == pytest.approx(expected_color_values.max()) - assert len(fig.axes) == 2 - assert colorbar.get_ylabel() == "Error / Std(1-step change)" + expected = errors.T.numpy() / np.array([[1.0], [100.0], [10.0]]) + np.testing.assert_allclose(image.get_array(), expected) + assert colorbar.get_ylabel() == "Error / state_std" + + plt.close(fig) + + +def test_plot_error_heatmap_diff_std_normalization(): + """diff_std mode: colors are Error / physical diff_std.""" + errors = torch.tensor( + [ + [1.0, 100.0, 10.0], + [2.0, 80.0, 5.0], + [3.0, 60.0, 2.5], + ] + ) + datastore = HeatmapDatastore( + n_vars=errors.shape[1], + state_std=[1.0, 100.0, 10.0], + state_diff_std_standardized=[1.0, 2.0, 0.5], + ) + + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="diff_std" + ) + ax = fig.axes[0] + image = ax.images[0] + colorbar = fig.axes[1] + + # physical diff_std = state_std * state_diff_std_standardized = [1, 200, 5] + expected = errors.T.numpy() / np.array([[1.0], [200.0], [5.0]]) + np.testing.assert_allclose(image.get_array(), expected) + assert colorbar.get_ylabel() == "Error / physical diff_std" plt.close(fig) -def test_plot_error_heatmap_uses_absolute_scale_without_stats(): - """Without standardization stats the colorbar should stay in raw units.""" +def test_plot_error_heatmap_falls_back_to_per_var_scale_without_stats(): + """Colors fall back to per-variable max when stats are unavailable.""" class NoStatsHeatmapDatastore(HeatmapDatastore): def get_standardization_dataarray(self, category): @@ -218,17 +247,45 @@ def get_standardization_dataarray(self, category): errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) datastore = NoStatsHeatmapDatastore(n_vars=errors.shape[1]) - with pytest.warns(UserWarning, match="falling back to absolute scale"): + with pytest.warns(UserWarning, match="falling back to per-variable scale"): fig = vis.plot_error_heatmap(errors, datastore=datastore) + ax = fig.axes[0] + image = ax.images[0] colorbar = fig.axes[1] - assert colorbar.get_ylabel() == "Absolute scale" + # errors_np after transpose: var0=[1,3], var1=[2,4]; max per var: [3,4] + expected = np.array([[1 / 3, 3 / 3], [2 / 4, 4 / 4]]) + np.testing.assert_allclose(image.get_array(), expected) + assert "[fallback]" in colorbar.get_ylabel() plt.close(fig) -def test_plot_error_heatmap_uses_state_std_only_when_diff_std_absent(): - """Without diff std the heatmap should fall back to state-std scaling.""" +def test_plot_error_heatmap_diff_std_falls_back_when_diff_std_absent(): + """diff_std mode falls back to per-var max when diff_std is missing.""" + + class StateStdOnlyDatastore(HeatmapDatastore): + def get_standardization_dataarray(self, category): + return ( + super() + .get_standardization_dataarray(category) + .drop_vars("state_diff_std_standardized") + ) + + errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) + datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]) + + with pytest.warns(UserWarning, match="falling back to per-variable scale"): + fig = vis.plot_error_heatmap( + errors, datastore=datastore, normalization="diff_std" + ) + colorbar = fig.axes[1] + assert "[fallback]" in colorbar.get_ylabel() + plt.close(fig) + + +def test_plot_error_heatmap_state_std_ignores_diff_std(): + """state_std mode is unaffected by presence or absence of diff_std.""" class StateStdOnlyDatastore(HeatmapDatastore): def get_standardization_dataarray(self, category): @@ -239,12 +296,17 @@ def get_standardization_dataarray(self, category): ) errors = torch.tensor([[2.0, 4.0], [1.0, 3.0]]) + datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]) fig = vis.plot_error_heatmap( - errors, - datastore=StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]), + errors, datastore=datastore, normalization="state_std" ) + ax = fig.axes[0] + image = ax.images[0] - assert fig.axes[1].get_ylabel() == "Relative scale (state stds)" + # errors_np after transpose: var0=[2,1], var1=[4,3]; state_std=[2,1] + expected = np.array([[2 / 2, 1 / 2], [4 / 1, 3 / 1]]) + np.testing.assert_allclose(image.get_array(), expected) + assert fig.axes[1].get_ylabel() == "Error / state_std" plt.close(fig)