diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 8a9cc8253..a1a1ab703 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -134,12 +134,18 @@
     def __len__(self):
         # check that there are enough forecast steps available to create
         # samples given the number of autoregressive steps requested
         n_forecast_steps = self.da_state.elapsed_forecast_duration.size
-        if n_forecast_steps < 2 + self.ar_steps:
+        required_steps = (
+            self.ar_steps
+            + max(2, self.num_past_forcing_steps)
+            + self.num_future_forcing_steps
+        )
+        if n_forecast_steps < required_steps:
             raise ValueError(
                 "The number of forecast steps available "
                 f"({n_forecast_steps}) is less than the required "
-                f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
-                "creating a sample with initial and target states."
+                f"({required_steps}) for "
+                "creating a sample with initial and target states "
+                "and forcing windows."
             )
         return self.da_state.analysis_time.size
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 419aece09..a24acea9c 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -5,6 +5,7 @@
 import numpy as np
 import pytest
 import torch
+import xarray as xr
 from torch.utils.data import DataLoader
 
 # First-party
@@ -259,3 +260,173 @@ def test_dataset_length(dataset_config):
     # Check that we can actually get last and first sample
     dataset[0]
     dataset[expected_len - 1]
+
+
+class ForecastDummyDatastore:
+    """Minimal forecast-mode datastore for testing WeatherDataset bounds
+    validation.
+
+    Simulates a datastore where data is organised as
+    (analysis_time, elapsed_forecast_duration, grid_index, feature) rather
+    than a flat time axis.
+
+    Parameters
+    ----------
+    n_analysis_times : int
+        Number of analysis times (i.e. number of forecast initialisations).
+    n_forecast_steps : int
+        Number of elapsed forecast steps available per analysis time.
+ """ + + is_forecast = True + is_ensemble = False + coords_projection = None + num_grid_points = 1 + config = {} + + def __init__(self, n_analysis_times: int, n_forecast_steps: int): + from datetime import timedelta + from pathlib import Path + + import pandas as pd + + self._n_forecast_steps = n_forecast_steps + self._step_length = timedelta(hours=1) + self.root_path = Path("dummy_forecast") + + analysis_times = pd.date_range( + "2020-01-01", periods=n_analysis_times, freq="1h" + ) + forecast_durations = pd.to_timedelta( + [timedelta(hours=h) for h in range(n_forecast_steps)] + ) + + # Build a minimal (analysis_time, elapsed_forecast_duration, + # grid_index, state_feature) DataArray that WeatherDataset.__len__ + # can inspect. + data = np.zeros( + (n_analysis_times, n_forecast_steps, self.num_grid_points, 1) + ) + self.da_state = xr.DataArray( + data, + dims=[ + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ], + coords={ + "analysis_time": analysis_times, + "elapsed_forecast_duration": forecast_durations, + }, + ) + + @property + def step_length(self): + return self._step_length + + def get_num_data_vars(self, category): + return 1 + + def get_dataarray(self, category, split, standardize=False): + if category == "state": + return self.da_state + return None # no forcing for simplicity + + def get_standardization_dataarray(self, category): + raise NotImplementedError() + + def get_vars_units(self, category): + return ["-"] + + def get_vars_names(self, category): + return [f"{category}_var_0"] + + def get_vars_long_names(self, category): + return [f"{category} variable 0"] + + def expected_dim_order(self, category=None): + return ( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + f"{category}_feature", + ) + + @property + def boundary_mask(self): + return xr.DataArray(np.zeros(1, dtype=int), dims=["grid_index"]) + + +@pytest.mark.parametrize( + "ar_steps,num_past_forcing_steps,num_future_forcing_steps," + 
"n_forecast_steps,should_raise", + [ + # Passes: exactly enough steps with no forcing window. + # Required = ar_steps + max(2, past=0) + future=0 = 3 + 2 + 0 = 5 + (3, 0, 0, 5, False), + # Passes: enough steps once forcing windows are accounted for. + # Required = 3 + max(2, 1) + 2 = 3 + 2 + 2 = 7 + (3, 1, 2, 7, False), + # Fails: the old check (2 + ar_steps = 5) would NOT catch this because + # n_forecast_steps=5 >= 5, but the forcing window pushes required to + # 3 + 2 + 2 = 7. This is the exact regression the fix addresses. + (3, 1, 2, 5, True), + # Fails: outright too few steps even without forcing. + # Required = 3 + 2 + 0 = 5, but only 4 available. + (3, 0, 0, 4, True), + ], +) +def test_forecast_dataset_bounds_validation( + ar_steps, + num_past_forcing_steps, + num_future_forcing_steps, + n_forecast_steps, + should_raise, +): + """Verify that WeatherDataset raises ValueError when forecast steps are + insufficient, correctly accounting for both ar_steps and forcing windows. + + This is a regression test for the bug where the bounds check only compared + against ``2 + ar_steps``, silently ignoring the extra steps consumed by + the forcing windows. With short forecast data, xarray would silently + truncate the slices and PyTorch Lightning's collate would later crash with + cryptic shape-mismatch errors. + + Parameters + ---------- + ar_steps : int + Number of autoregressive prediction steps. + num_past_forcing_steps : int + Number of past forcing steps in the window. + num_future_forcing_steps : int + Number of future forcing steps in the window. + n_forecast_steps : int + Number of forecast steps available in the dummy datastore. + should_raise : bool + Whether instantiating WeatherDataset is expected to raise ValueError. 
+ """ + datastore = ForecastDummyDatastore( + n_analysis_times=4, n_forecast_steps=n_forecast_steps + ) + + if should_raise: + with pytest.raises(ValueError, match="forecast steps available"): + WeatherDataset( + datastore=datastore, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + else: + dataset = WeatherDataset( + datastore=datastore, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + assert len(dataset) > 0, "Dataset should contain at least one sample"