Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,12 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):

# Normalize all errors to [0,1] for color map
max_errors = errors_np.max(axis=1) # d_f
errors_norm = errors_np / np.expand_dims(max_errors, axis=1)
errors_norm = np.divide(
errors_np,
max_errors[:, None],
out=np.zeros_like(errors_np),
where=max_errors[:, None] != 0,
)

time_step_int, time_step_unit = utils.get_integer_time(step_length)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,27 @@ def test_plot_error_map() -> None:
assert len(ax.texts) == pred_steps * d_f


def test_plot_error_map_zero_row() -> None:
"""Regression test: zero-error variable should not produce NaNs."""
datastore = init_datastore_example("dummydata")
d_f = len(datastore.get_vars_names(category="state"))
pred_steps = 3

errors = torch.ones((pred_steps, d_f), dtype=torch.float32)
errors[:, 0] = 0.0

fig = vis.plot_error_map(
errors=errors,
datastore=datastore,
title="Zero Error Row",
)

ax = fig.axes[0]
plotted = np.asarray(ax.images[0].get_array())
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this really get back the normalized errors? I think this code would need some explanations as it is pretty much unpacking the produced figure. Hard to understand without knowing matplotlib internals.


assert np.isfinite(plotted).all()


def test_plot_spatial_error() -> None:
"""Check that plot_spatial_error runs without error and returns a Figure."""
datastore = init_datastore_example("dummydata")
Expand Down