diff --git a/CHANGELOG.md b/CHANGELOG.md index ea899ff5..c67c37a7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ This release introduces new features including GIF animation support, wandb run - Infer spatial coordinate names for MDPDatastore (rather than assuming names `x` and `y`), allows for e.g. lat/lon regular grids [\#169](https://github.com/mllam/neural-lam/pull/169) @leifdenby ### Fixed +- Fix `WeatherDataset` boundary handling for out-of-range indexing and forecast-mode forcing horizon validation to prevent malformed samples [\#312](https://github.com/mllam/neural-lam/pull/312) - Fix validation crash in `plot_error_map` and resolve DDP NCCL initialization error on single-device setups [\#193](https://github.com/mllam/neural-lam/pull/193) @AdityaKumarSethia diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 5547fdd4..3257e7f3 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -85,19 +85,6 @@ def __init__( stacklevel=2, ) - # check that with the provided data-arrays and ar_steps that we have a - # non-zero amount of samples - if self.__len__() <= 0 and self.da_state is not None: - raise ValueError( - "The provided datastore only provides " - f"{len(self.da_state.time)} total time steps, which is too few " - "to create a single sample for the WeatherDataset " - f"configuration used in the `{split}` split. You could try " - "either reducing the number of autoregressive steps " - "(`ar_steps`) and/or the forcing window size " - "(`num_past_forcing_steps` and `num_future_forcing_steps`)" - ) - # Check the dimensions and their ordering parts = dict(state=self.da_state) if self.da_forcing is not None: @@ -116,6 +103,27 @@ def __init__( "transpose the data in `BaseDatastore.get_dataarray`?" ) + self._validate_dataset_configuration() + self._dataset_len = self._compute_dataset_len() + + # check that with the provided data-arrays and ar_steps that we have a + # non-zero amount of samples + if self._dataset_len <= 0 and self.da_state is not None: + time_dim = ( + self.da_state.analysis_time + if self.datastore.is_forecast + else self.da_state.time + ) + raise ValueError( + "The provided datastore only provides " + f"{len(time_dim)} total time steps, which is too few " + "to create a single sample for the WeatherDataset " + f"configuration used in the `{split}` split. You could try " + "either reducing the number of autoregressive steps " + "(`ar_steps`) and/or the forcing window size " + "(`num_past_forcing_steps` and `num_future_forcing_steps`)" + ) + # Set up for standardization # TODO: This will become part of ar_model.py soon! self.standardize = standardize @@ -150,56 +158,115 @@ def __init__( else: self.forcing_std_safe = None - def _compute_std_safe(self, std: xr.DataArray, feature: str): - eps = np.finfo(std.dtype).eps - if bool((std <= eps).any()): - logger.warning( - f"Some {feature} features have near-zero std and will be " - "standardized using machine epsilon to avoid NaN." + def _validate_dataset_configuration(self): + """Validate dataset structure once during initialization.""" + if not self.datastore.is_forecast: + return + + required_state_forecast_steps = ( + max(2, self.num_past_forcing_steps) + self.ar_steps + ) + n_state_forecast_steps = self.da_state.elapsed_forecast_duration.size + if n_state_forecast_steps < required_state_forecast_steps: + raise ValueError( + "The number of forecast steps available " + f"({n_state_forecast_steps}) is less than the required " + f"{required_state_forecast_steps} " + f"(max(2, num_past_forcing_steps=" + f"{self.num_past_forcing_steps}) + ar_steps=" + f"{self.ar_steps}) for creating a sample with initial " + "and target states." ) - return std.where(std > eps, other=eps) - def __len__(self): - if self.datastore.is_forecast: - # for now we simply create a single sample for each analysis time - # and then take the first (2 + ar_steps) forecast times. - # If the datastore returns an ensemble of state realisations and - # `load_single_member=False`, each ensemble member is exposed as an - # independent sample by scaling the base dataset length below. - - # check that there are enough forecast steps available to create - # samples given the number of autoregressive steps requested - n_forecast_steps = self.da_state.elapsed_forecast_duration.size - if n_forecast_steps < 2 + self.ar_steps: - raise ValueError( - "The number of forecast steps available " - f"({n_forecast_steps}) is less than the required " - f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for " - "creating a sample with initial and target states." - ) + if self.da_forcing is None: + return + if not np.array_equal( + self.da_state.analysis_time.values, + self.da_forcing.analysis_time.values, + ): + raise ValueError( + "State and forcing analysis times must match for " + "forecast-mode datasets." + ) + + if not np.array_equal( + self.da_state.elapsed_forecast_duration.values[ + :required_state_forecast_steps + ], + self.da_forcing.elapsed_forecast_duration.values[ + :required_state_forecast_steps + ], + ): + raise ValueError( + "State and forcing forecast lead times must match across " + "the state forecast horizon used for target alignment in " + "forecast-mode datasets." + ) + + n_forcing_forecast_steps = ( + self.da_forcing.elapsed_forecast_duration.size + ) + required_forcing_forecast_steps = ( + max(2, self.num_past_forcing_steps) + + self.ar_steps + + self.num_future_forcing_steps + ) + if n_forcing_forecast_steps < required_forcing_forecast_steps: + raise ValueError( + "The number of forcing forecast steps available " + f"({n_forcing_forecast_steps}) is less than the required " + f"{required_forcing_forecast_steps} " + f"(max(2, num_past_forcing_steps=" + f"{self.num_past_forcing_steps})" + f" + ar_steps={self.ar_steps} + " + f"num_future_forcing_steps=" + f"{self.num_future_forcing_steps}) " + "for constructing forcing windows." + ) + + def _compute_dataset_len(self): + """Compute dataset length without running structural validation.""" + if self.datastore.is_forecast: base_len = self.da_state.analysis_time.size else: - # Calculate the number of samples in the dataset n_samples = total - # time steps - (autoregressive steps + past forcing + future - # forcing) - #: - # Where: - # - total time steps: len(self.da_state.time) - # - autoregressive steps: self.ar_steps - # - past forcing: max(2, self.num_past_forcing_steps) (at least 2 - # time steps are required for the initial state) - # - future forcing: self.num_future_forcing_steps - base_len = ( + n_state_samples = ( len(self.da_state.time) - self.ar_steps - max(2, self.num_past_forcing_steps) - self.num_future_forcing_steps + + 1 ) + if self.da_forcing is None: + base_len = max(0, n_state_samples) + else: + n_forcing_samples = ( + len(self.da_forcing.time) + - self.ar_steps + - max(2, self.num_past_forcing_steps) + - self.num_future_forcing_steps + + 1 + ) + base_len = max(0, min(n_state_samples, n_forcing_samples)) + if self.datastore.is_ensemble and not self.load_single_member: return base_len * self.da_state.ensemble_member.size return base_len + def _compute_std_safe(self, std: xr.DataArray, feature: str): + eps = np.finfo(std.dtype).eps + if bool((std <= eps).any()): + logger.warning( + f"Some {feature} features have near-zero std and will be " + "standardized using machine epsilon to avoid NaN." + ) + return std.where(std > eps, other=eps) + + def __len__(self): + if hasattr(self, "_dataset_len"): + return self._dataset_len + return self._compute_dataset_len() + def _slice_state_time(self, da_state, idx, n_steps: int): """ Produce a time slice of the given dataarray `da_state` (state) starting @@ -502,6 +569,17 @@ def __getitem__(self, idx): the target steps. """ + dataset_len = self._dataset_len + + # Match Python sequence semantics for negative indexing. + if idx < 0: + idx += dataset_len + if idx < 0 or idx >= dataset_len: + raise IndexError( + f"Index {idx} is out of bounds for dataset of size " + f"{dataset_len}" + ) + ( da_init_states, da_target_states, diff --git a/tests/test_datasets.py b/tests/test_datasets.py index b0e03556..ccdc1f7d 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -5,6 +5,7 @@ import numpy as np import pytest import torch +import xarray as xr from torch.utils.data import DataLoader # First-party @@ -18,6 +19,45 @@ from tests.dummy_datastore import DummyDatastore, EnsembleDummyDatastore +class ForecastArrayDatastore(DummyDatastore): + is_forecast = True + + def __init__(self, da_state, da_forcing): + super().__init__(n_grid_points=1, n_timesteps=1) + self._state_da = da_state + self._forcing_da = da_forcing + self.is_ensemble = "ensemble_member" in da_state.dims + self.has_ensemble_forcing = ( + da_forcing is not None and "ensemble_member" in da_forcing.dims + ) + + def get_dataarray(self, category, split, **kwargs): + if category == "state": + return self._state_da + if category == "forcing": + return self._forcing_da + return super().get_dataarray(category, split, **kwargs) + + +def make_forecast_dataset( + da_state, + da_forcing, + *, + ar_steps, + num_past_forcing_steps, + num_future_forcing_steps, +): + datastore = ForecastArrayDatastore(da_state=da_state, da_forcing=da_forcing) + return WeatherDataset( + datastore=datastore, + split="train", + ar_steps=ar_steps, + num_past_forcing_steps=num_past_forcing_steps, + num_future_forcing_steps=num_future_forcing_steps, + standardize=False, + ) + + @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_dataset_item_shapes(datastore_name): """Check that the `datastore.get_dataarray` method is implemented. @@ -227,12 +267,12 @@ def _create_graph(): @pytest.mark.parametrize( "dataset_config", [ - {"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 3}, - {"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 3}, - {"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 5}, - {"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 5}, - {"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 7}, - {"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 8}, + {"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 2}, + {"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 2}, + {"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 4}, + {"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 4}, + {"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 6}, + {"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 7}, ], ) def test_dataset_length(dataset_config): @@ -458,6 +498,504 @@ def test_standardization_with_zero_std(): ).any(), "NaN found after _compute_std_safe" +def test_dataset_out_of_bounds_indexing_raises(): + """Ensure out-of-range indexing fails instead of returning bad samples.""" + datastore = DummyDatastore(n_grid_points=4, n_timesteps=10) + dataset = WeatherDataset( + datastore=datastore, + split="train", + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + ) + + # In-bounds indices work, including Python-style negative indexing. + dataset[0] + dataset[len(dataset) - 1] + dataset[-1] + + # Out-of-bounds indices must fail explicitly. + with pytest.raises(IndexError): + dataset[len(dataset)] + with pytest.raises(IndexError): + dataset[len(dataset) + 1] + with pytest.raises(IndexError): + dataset[-len(dataset) - 1] + + +def test_negative_indexing_does_not_call_len_in_getitem(): + class LenBombDataset(WeatherDataset): + def __len__(self): + raise AssertionError("__getitem__ should use cached dataset length") + + datastore = DummyDatastore(n_grid_points=4, n_timesteps=10) + dataset = LenBombDataset( + datastore=datastore, + split="train", + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + ) + + dataset[-1] + + +def test_forecast_len_raises_when_forcing_horizon_too_short(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + elapsed = np.arange(5, dtype="timedelta64[h]").astype("timedelta64[ns]") + + da_state = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + with pytest.raises( + ValueError, + match="forecast lead times must match|forcing forecast steps", + ): + make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=2, + ) + + +def test_forecast_len_raises_when_state_horizon_too_short_for_past_forcing(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + elapsed = np.arange(4, dtype="timedelta64[h]").astype("timedelta64[ns]") + + da_state = xr.DataArray( + np.zeros((2, 4, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + + with pytest.raises(ValueError, match="initial and target states"): + make_forecast_dataset( + da_state, + None, + ar_steps=1, + num_past_forcing_steps=4, + num_future_forcing_steps=0, + ) + + +def test_forecast_len_accepts_exact_state_horizon_for_past_forcing(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + elapsed = np.arange(5, dtype="timedelta64[h]").astype("timedelta64[ns]") + + da_state = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + + dataset = make_forecast_dataset( + da_state, + None, + ar_steps=1, + num_past_forcing_steps=4, + num_future_forcing_steps=0, + ) + assert len(dataset) == 2 + + +def test_forecast_len_accepts_exact_forcing_horizon(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + elapsed = np.arange(6, dtype="timedelta64[h]").astype("timedelta64[ns]") + + da_state = xr.DataArray( + np.zeros((2, 6, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((2, 6, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + dataset = make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=2, + ) + assert len(dataset) == 2 + + +def test_forecast_len_accepts_longer_forcing_horizon_with_matching_prefix(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + state_elapsed = np.arange(4, dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + forcing_elapsed = np.arange(6, dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + + da_state = xr.DataArray( + np.zeros((2, 4, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": state_elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((2, 6, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": forcing_elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + dataset = make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=2, + ) + + assert len(dataset) == 2 + + +def test_forecast_len_raises_when_forcing_shorter_than_state_horizon(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + state_elapsed = np.arange(6, dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + forcing_elapsed = np.arange(5, dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + + da_state = xr.DataArray( + np.zeros((2, 6, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": state_elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": forcing_elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + with pytest.raises( + ValueError, + match="forecast lead times must match|forcing forecast steps", + ): + make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=2, + ) + + +def test_forecast_len_raises_when_analysis_times_do_not_match(): + state_analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + forcing_analysis_time = np.array( + ["2021-01-01T00:00:00"], + dtype="datetime64[ns]", + ) + elapsed = np.arange(5, dtype="timedelta64[h]").astype("timedelta64[ns]") + + da_state = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": state_analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((1, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": forcing_analysis_time, + "elapsed_forecast_duration": elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + with pytest.raises(ValueError, match="analysis times must match"): + make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + ) + + +def test_forecast_len_raises_when_forecast_lead_times_do_not_match(): + analysis_time = np.array( + ["2021-01-01T00:00:00", "2021-01-01T01:00:00"], + dtype="datetime64[ns]", + ) + state_elapsed = np.arange(5, dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + forcing_elapsed = np.array([0, 2, 4, 6, 8], dtype="timedelta64[h]").astype( + "timedelta64[ns]" + ) + + da_state = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": state_elapsed, + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + da_forcing = xr.DataArray( + np.zeros((2, 5, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "forcing_feature", + ), + coords={ + "analysis_time": analysis_time, + "elapsed_forecast_duration": forcing_elapsed, + "grid_index": [0], + "forcing_feature": ["forcing_feat_0"], + }, + ) + + with pytest.raises(ValueError, match="forecast lead times must match"): + make_forecast_dataset( + da_state, + da_forcing, + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=1, + ) + + +def test_weather_dataset_forecast_empty_split_raises_value_error(): + """Empty forecast splits should raise the intended user-facing error.""" + # Third-party + import xarray as xr + + class EmptyForecastDatastore(DummyDatastore): + is_forecast = True + + def get_dataarray(self, category, split, **kwargs): + if category == "state": + return xr.DataArray( + np.zeros((0, 3, 1, 1), dtype=np.float32), + dims=( + "analysis_time", + "elapsed_forecast_duration", + "grid_index", + "state_feature", + ), + coords={ + "analysis_time": np.array([], dtype="datetime64[ns]"), + "elapsed_forecast_duration": np.arange( + 3, dtype="timedelta64[h]" + ).astype("timedelta64[ns]"), + "grid_index": [0], + "state_feature": ["state_feat_0"], + }, + ) + if category == "forcing": + return None + return super().get_dataarray( + category=category, split=split, **kwargs + ) + + datastore = EmptyForecastDatastore(n_grid_points=4, n_timesteps=10) + + with pytest.raises(ValueError, match="0 total time steps"): + WeatherDataset( + datastore=datastore, + split="train", + ar_steps=1, + num_past_forcing_steps=1, + num_future_forcing_steps=0, + standardize=False, + ) + + +def test_analysis_len_limited_by_shorter_forcing_horizon(): + """Analysis-mode datasets must not expose samples whose forcing windows + overrun the available forcing time axis.""" + + class ShortForcingDatastore(DummyDatastore): + def get_dataarray(self, category, split, **kwargs): + da = super().get_dataarray(category=category, split=split, **kwargs) + if category == "forcing": + return da.isel(time=slice(None, -1)) + return da + + datastore = ShortForcingDatastore(n_grid_points=4, n_timesteps=7) + dataset = WeatherDataset( + datastore=datastore, + split="train", + ar_steps=2, + num_past_forcing_steps=1, + num_future_forcing_steps=2, + standardize=False, + ) + + assert len(dataset) == 1 + + _, _, forcing, _ = dataset[0] + assert not torch.isnan(forcing).any() + + with pytest.raises(IndexError): + dataset[1] + + def test_weather_dataset_no_forcing_standardize(): """Regression test: WeatherDataset must not raise AttributeError when the datastore has no forcing data and standardize=True (the default). @@ -493,8 +1031,5 @@ def get_dataarray(self, category, split, **kwargs): assert dataset.da_forcing_mean is None assert dataset.da_forcing_std is None - # Ensure we can still retrieve a sample (forcing tensor should be empty) - init_states, target_states, forcing, target_times = dataset[0] - assert ( - forcing.shape[-1] == 0 - ), "Expected zero forcing features when forcing is None" + _, _, forcing, _ = dataset[0] + assert forcing.shape[-1] == 0