diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 06f2f6d3..ca1779ae 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -167,7 +167,12 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None): # Normalize all errors to [0,1] for color map max_errors = errors_np.max(axis=1) # d_f - errors_norm = errors_np / np.expand_dims(max_errors, axis=1) + errors_norm = np.divide( + errors_np, + max_errors[:, None], + out=np.zeros_like(errors_np), + where=max_errors[:, None] != 0, + ) time_step_int, time_step_unit = utils.get_integer_time(step_length) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ba03857c..c37cc828 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -125,6 +125,27 @@ def test_plot_error_map() -> None: assert len(ax.texts) == pred_steps * d_f +def test_plot_error_map_zero_row() -> None: + """Regression test: zero-error variable should not produce NaNs.""" + datastore = init_datastore_example("dummydata") + d_f = len(datastore.get_vars_names(category="state")) + pred_steps = 3 + + errors = torch.ones((pred_steps, d_f), dtype=torch.float32) + errors[:, 0] = 0.0 + + fig = vis.plot_error_map( + errors=errors, + datastore=datastore, + title="Zero Error Row", + ) + + ax = fig.axes[0] + plotted = np.asarray(ax.images[0].get_array()) + + assert np.isfinite(plotted).all() + + def test_plot_spatial_error() -> None: """Check that plot_spatial_error runs without error and returns a Figure.""" datastore = init_datastore_example("dummydata")