Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 83 additions & 47 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,62 +119,96 @@ def _get_feature_scale(


def _get_heatmap_color_values(
errors_np: np.ndarray, datastore: BaseRegularGridDatastore
) -> tuple[np.ndarray, str]:
errors_np: np.ndarray,
datastore: BaseRegularGridDatastore,
normalization: str,
) -> tuple[np.ndarray, str, matplotlib.colors.Colormap]:
"""
Normalize heatmap colors to a cross-variable relative scale.
Normalize heatmap colors according to `normalization`.

Returns a 3-tuple: (color_values, colorbar_label, cmap).
The returned array drives the colormap only; the numeric annotations in the
heatmap remain the original (physical-unit) values passed in `errors_np`.

Scaling logic:
- Prefer a relative scale based on datastore standardization stats.
Start with `state_std` (per-variable climatological std).
- If `state_diff_std_standardized` is available, also fold that in to
represent error relative to typical one-step variability.
- If any required stats are missing or invalid, fall back to absolute
scaling using the raw error values.
Both modes fall back to per-variable max normalization when their required
stat is unavailable, appending "[fallback]" to the colorbar label.
"""
try:
ds_state_stats = datastore.get_standardization_dataarray(
category="state"

def _per_var_fallback():
max_err = errors_np.max(axis=1, keepdims=True)
safe = np.where(max_err > np.finfo(float).eps, max_err, 1.0)
return (
errors_np / safe,
"Per-variable scale (relative to max error) [fallback]",
_HEATMAP_CMAP,
)

try:
ds_stats = datastore.get_standardization_dataarray(category="state")
except (AttributeError, KeyError, TypeError, ValueError) as exc:
warnings.warn(
f"Could not load standardization stats ({exc}); "
"falling back to absolute scale.",
"falling back to per-variable scale.",
UserWarning,
stacklevel=3,
)
return errors_np, "Absolute scale"
return _per_var_fallback()

n_vars = errors_np.shape[0]
state_std = _get_feature_scale(ds_state_stats, "state_std", n_vars)
if state_std is None:
warnings.warn(
"State standardization stats are unavailable; "
"falling back to absolute scale.",
UserWarning,
stacklevel=3,

if normalization == "state_std":
state_std = _get_feature_scale(ds_stats, "state_std", n_vars)
if state_std is None:
warnings.warn(
"state_std unavailable; falling back to per-variable scale.",
UserWarning,
stacklevel=3,
)
return _per_var_fallback()
safe_std = np.where(
np.isfinite(state_std) & (state_std > np.finfo(float).eps),
state_std,
1.0,
)
return errors_np, "Absolute scale"
return errors_np / safe_std[:, None], "Error / state_std", _HEATMAP_CMAP

scale = state_std
colorbar_label = "Relative scale (state stds)"
if normalization == "diff_std":
diff_std_std = _get_feature_scale(
ds_stats, "state_diff_std_standardized", n_vars
)
if diff_std_std is None:
warnings.warn(
"state_diff_std_standardized unavailable; "
"falling back to per-variable scale.",
UserWarning,
stacklevel=3,
)
return _per_var_fallback()
state_std = _get_feature_scale(ds_stats, "state_std", n_vars)
if state_std is None:
warnings.warn(
"state_std unavailable (needed to recover physical diff_std); "
"falling back to per-variable scale.",
UserWarning,
stacklevel=3,
)
return _per_var_fallback()
scale = state_std * diff_std_std # physical diff_std
safe = np.where(
np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps),
scale,
1.0,
)
return (
errors_np / safe[:, None],
"Error / physical diff_std",
_HEATMAP_CMAP,
)

state_diff_std_standardized = _get_feature_scale(
ds_state_stats, "state_diff_std_standardized", n_vars
raise ValueError(
f"Unknown normalization {normalization!r}; "
"expected 'state_std' or 'diff_std'."
)
if state_diff_std_standardized is not None:
scale = scale * state_diff_std_standardized
colorbar_label = "Error / Std(1-step change)"

safe_scale = np.where(
np.isfinite(scale) & (np.abs(scale) > np.finfo(float).eps),
scale,
1.0,
)
return errors_np / safe_scale[:, None], colorbar_label


def _get_annotation_text_color(
Expand Down Expand Up @@ -319,6 +353,7 @@ def plot_error_heatmap(
errors,
datastore: BaseRegularGridDatastore,
title=None,
normalization: str = "state_std",
):
"""
Plot a heatmap of errors for state variables across forecast lead times.
Expand All @@ -332,32 +367,33 @@ def plot_error_heatmap(
Datastore providing step length and variable metadata.
title : str, optional
Optional title for the figure.
normalization : {"state_std", "diff_std"}, default "state_std"
Color scaling mode. "state_std" divides by climatological std;
"diff_std" divides by the typical one-step change magnitude. Both fall
back to per-variable max error when the required stats are missing.

Notes
-----
The heatmap colormap is driven by a relative cross-variable scale derived
from datastore standardization stats (see `_get_heatmap_color_values`).
If those stats are unavailable, the plot falls back to absolute scaling
on the raw `errors` values.
Color scaling is controlled by `normalization`; see
`_get_heatmap_color_values` for the full fallback logic. When stats are
unavailable the colorbar label includes "[fallback]".
"""
errors_np = _to_heatmap_matrix(errors)
d_f, pred_steps = errors_np.shape
step_length = datastore.step_length

time_step_int, time_step_unit = utils.get_integer_time(step_length)
layout = _compute_heatmap_layout(n_rows=d_f, n_cols=pred_steps)
color_values_np, colorbar_label = _get_heatmap_color_values(
errors_np, datastore
color_values_np, colorbar_label, heatmap_cmap = _get_heatmap_color_values(
errors_np, datastore, normalization
)

finite_color_values = color_values_np[np.isfinite(color_values_np)]
if finite_color_values.size == 0:
vmin, vmax = 0.0, 1.0
else:
vmin = float(finite_color_values.min())
vmax = float(finite_color_values.max())
if vmin >= 0.0:
vmin = 0.0
vmin = min(0.0, vmin)
if np.isclose(vmin, vmax):
vmax = vmin + 1.0

Expand All @@ -368,7 +404,7 @@ def plot_error_heatmap(

im = ax.imshow(
color_values_np,
cmap=_HEATMAP_CMAP,
cmap=heatmap_cmap,
vmin=vmin,
vmax=vmax,
interpolation="none",
Expand Down
98 changes: 80 additions & 18 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,8 @@ def test_plot_error_map() -> None:
assert colorbar_ax.get_ylabel() != ""


def test_plot_error_heatmap_uses_relative_color_scale():
"""Heatmap colors should compare relative magnitudes across variables."""
def test_plot_error_heatmap_state_std_normalization():
"""state_std mode: colors are Error / state_std."""
errors = torch.tensor(
[
[1.0, 100.0, 10.0],
Expand All @@ -193,23 +193,52 @@ def test_plot_error_heatmap_uses_relative_color_scale():
state_diff_std_standardized=[1.0, 2.0, 0.5],
)

fig = vis.plot_error_heatmap(errors, datastore=datastore)
fig = vis.plot_error_heatmap(
errors, datastore=datastore, normalization="state_std"
)
ax = fig.axes[0]
image = ax.images[0]
colorbar = fig.axes[1]

expected_color_values = errors.T.numpy() / np.array([[1.0], [200.0], [5.0]])
np.testing.assert_allclose(image.get_array(), expected_color_values)
assert image.norm.vmin == 0.0
assert image.norm.vmax == pytest.approx(expected_color_values.max())
assert len(fig.axes) == 2
assert colorbar.get_ylabel() == "Error / Std(1-step change)"
expected = errors.T.numpy() / np.array([[1.0], [100.0], [10.0]])
np.testing.assert_allclose(image.get_array(), expected)
assert colorbar.get_ylabel() == "Error / state_std"

plt.close(fig)


def test_plot_error_heatmap_diff_std_normalization():
"""diff_std mode: colors are Error / physical diff_std."""
errors = torch.tensor(
[
[1.0, 100.0, 10.0],
[2.0, 80.0, 5.0],
[3.0, 60.0, 2.5],
]
)
datastore = HeatmapDatastore(
n_vars=errors.shape[1],
state_std=[1.0, 100.0, 10.0],
state_diff_std_standardized=[1.0, 2.0, 0.5],
)

fig = vis.plot_error_heatmap(
errors, datastore=datastore, normalization="diff_std"
)
ax = fig.axes[0]
image = ax.images[0]
colorbar = fig.axes[1]

# physical diff_std = state_std * state_diff_std_standardized = [1, 200, 5]
expected = errors.T.numpy() / np.array([[1.0], [200.0], [5.0]])
np.testing.assert_allclose(image.get_array(), expected)
assert colorbar.get_ylabel() == "Error / physical diff_std"

plt.close(fig)


def test_plot_error_heatmap_uses_absolute_scale_without_stats():
"""Without standardization stats the colorbar should stay in raw units."""
def test_plot_error_heatmap_falls_back_to_per_var_scale_without_stats():
"""Colors fall back to per-variable max when stats are unavailable."""

class NoStatsHeatmapDatastore(HeatmapDatastore):
def get_standardization_dataarray(self, category):
Expand All @@ -218,17 +247,45 @@ def get_standardization_dataarray(self, category):
errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
datastore = NoStatsHeatmapDatastore(n_vars=errors.shape[1])

with pytest.warns(UserWarning, match="falling back to absolute scale"):
with pytest.warns(UserWarning, match="falling back to per-variable scale"):
fig = vis.plot_error_heatmap(errors, datastore=datastore)
ax = fig.axes[0]
image = ax.images[0]
colorbar = fig.axes[1]

assert colorbar.get_ylabel() == "Absolute scale"
# errors_np after transpose: var0=[1,3], var1=[2,4]; max per var: [3,4]
expected = np.array([[1 / 3, 3 / 3], [2 / 4, 4 / 4]])
np.testing.assert_allclose(image.get_array(), expected)
assert "[fallback]" in colorbar.get_ylabel()

plt.close(fig)


def test_plot_error_heatmap_uses_state_std_only_when_diff_std_absent():
"""Without diff std the heatmap should fall back to state-std scaling."""
def test_plot_error_heatmap_diff_std_falls_back_when_diff_std_absent():
"""diff_std mode falls back to per-var max when diff_std is missing."""

class StateStdOnlyDatastore(HeatmapDatastore):
def get_standardization_dataarray(self, category):
return (
super()
.get_standardization_dataarray(category)
.drop_vars("state_diff_std_standardized")
)

errors = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0])

with pytest.warns(UserWarning, match="falling back to per-variable scale"):
fig = vis.plot_error_heatmap(
errors, datastore=datastore, normalization="diff_std"
)
colorbar = fig.axes[1]
assert "[fallback]" in colorbar.get_ylabel()
plt.close(fig)


def test_plot_error_heatmap_state_std_ignores_diff_std():
"""state_std mode is unaffected by presence or absence of diff_std."""

class StateStdOnlyDatastore(HeatmapDatastore):
def get_standardization_dataarray(self, category):
Expand All @@ -239,12 +296,17 @@ def get_standardization_dataarray(self, category):
)

errors = torch.tensor([[2.0, 4.0], [1.0, 3.0]])
datastore = StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0])
fig = vis.plot_error_heatmap(
errors,
datastore=StateStdOnlyDatastore(n_vars=2, state_std=[2.0, 1.0]),
errors, datastore=datastore, normalization="state_std"
)
ax = fig.axes[0]
image = ax.images[0]

assert fig.axes[1].get_ylabel() == "Relative scale (state stds)"
# errors_np after transpose: var0=[2,1], var1=[4,3]; state_std=[2,1]
expected = np.array([[2 / 2, 1 / 2], [4 / 1, 3 / 1]])
np.testing.assert_allclose(image.get_array(), expected)
assert fig.axes[1].get_ylabel() == "Error / state_std"
plt.close(fig)


Expand Down
Loading