Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ def __len__(self):
# check that there are enough forecast steps available to create
# samples given the number of autoregressive steps requested
n_forecast_steps = self.da_state.elapsed_forecast_duration.size
if n_forecast_steps < 2 + self.ar_steps:
required_steps = self.ar_steps + max(2, self.num_past_forcing_steps) + self.num_future_forcing_steps
if n_forecast_steps < required_steps:
raise ValueError(
"The number of forecast steps available "
f"({n_forecast_steps}) is less than the required "
f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
"creating a sample with initial and target states."
f"({required_steps}) for "
"creating a sample with initial and target states and forcing windows."
)

return self.da_state.analysis_time.size
Expand Down
171 changes: 171 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import pytest
import torch
import xarray as xr
from torch.utils.data import DataLoader

# First-party
Expand Down Expand Up @@ -259,3 +260,173 @@ def test_dataset_length(dataset_config):
# Check that we can actually get last and first sample
dataset[0]
dataset[expected_len - 1]


class ForecastDummyDatastore:
    """Minimal stand-in for a forecast-mode datastore, used to exercise
    the bounds validation in WeatherDataset.

    Mimics a datastore whose state data is laid out as
    (analysis_time, elapsed_forecast_duration, grid_index, feature)
    instead of a single flat time dimension.

    Parameters
    ----------
    n_analysis_times : int
        How many forecast initialisations (analysis times) to simulate.
    n_forecast_steps : int
        How many elapsed forecast steps each initialisation provides.
    """

    is_forecast = True
    is_ensemble = False
    coords_projection = None
    num_grid_points = 1
    config = {}

    def __init__(self, n_analysis_times: int, n_forecast_steps: int):
        from datetime import timedelta
        from pathlib import Path

        import pandas as pd

        self._n_forecast_steps = n_forecast_steps
        self._step_length = timedelta(hours=1)
        self.root_path = Path("dummy_forecast")

        init_times = pd.date_range(
            "2020-01-01", periods=n_analysis_times, freq="1h"
        )
        lead_times = pd.to_timedelta(
            [timedelta(hours=h) for h in range(n_forecast_steps)]
        )

        # Just enough structure for WeatherDataset.__len__ to inspect:
        # an all-zero array shaped (analysis_time,
        # elapsed_forecast_duration, grid_index, state_feature).
        shape = (n_analysis_times, n_forecast_steps, self.num_grid_points, 1)
        self.da_state = xr.DataArray(
            np.zeros(shape),
            dims=[
                "analysis_time",
                "elapsed_forecast_duration",
                "grid_index",
                "state_feature",
            ],
            coords={
                "analysis_time": init_times,
                "elapsed_forecast_duration": lead_times,
            },
        )

    @property
    def step_length(self):
        # Time between consecutive forecast steps (fixed at 1 hour).
        return self._step_length

    def get_num_data_vars(self, category):
        # A single variable per category keeps the dummy minimal.
        return 1

    def get_dataarray(self, category, split, standardize=False):
        # Only state data is provided; forcing/static categories yield
        # nothing so the dataset under test stays as simple as possible.
        return self.da_state if category == "state" else None

    def get_standardization_dataarray(self, category):
        raise NotImplementedError()

    def get_vars_units(self, category):
        return ["-"]

    def get_vars_names(self, category):
        return [f"{category}_var_0"]

    def get_vars_long_names(self, category):
        return [f"{category} variable 0"]

    def expected_dim_order(self, category=None):
        return (
            "analysis_time",
            "elapsed_forecast_duration",
            "grid_index",
            f"{category}_feature",
        )

    @property
    def boundary_mask(self):
        # Single interior grid point, i.e. nothing is masked as boundary.
        return xr.DataArray(np.zeros(1, dtype=int), dims=["grid_index"])


@pytest.mark.parametrize(
    "ar_steps,num_past_forcing_steps,num_future_forcing_steps,"
    "n_forecast_steps,should_raise",
    [
        # OK: no forcing window, exactly the minimum.
        # required = ar_steps + max(2, past=0) + future=0 = 3 + 2 + 0 = 5
        (3, 0, 0, 5, False),
        # OK: forcing windows included, still enough steps.
        # required = 3 + max(2, 1) + 2 = 3 + 2 + 2 = 7
        (3, 1, 2, 7, False),
        # Must raise: the pre-fix check (2 + ar_steps = 5) accepted
        # n_forecast_steps=5, but forcing windows raise the requirement
        # to 3 + 2 + 2 = 7 — the exact regression being guarded against.
        (3, 1, 2, 5, True),
        # Must raise: too few steps even with no forcing window.
        # required = 3 + 2 + 0 = 5, only 4 available.
        (3, 0, 0, 4, True),
    ],
)
def test_forecast_dataset_bounds_validation(
    ar_steps,
    num_past_forcing_steps,
    num_future_forcing_steps,
    n_forecast_steps,
    should_raise,
):
    """Check that WeatherDataset rejects forecast data with too few steps,
    counting both the autoregressive steps and the forcing windows.

    Regression test: the bounds check used to compare only against
    ``2 + ar_steps`` and ignored the steps consumed by the forcing
    windows. With short forecast data, xarray would silently truncate
    the slices and PyTorch Lightning's collate would later crash with
    cryptic shape-mismatch errors.

    Parameters
    ----------
    ar_steps : int
        Number of autoregressive prediction steps.
    num_past_forcing_steps : int
        Number of past forcing steps in the window.
    num_future_forcing_steps : int
        Number of future forcing steps in the window.
    n_forecast_steps : int
        Number of forecast steps available in the dummy datastore.
    should_raise : bool
        Whether instantiating WeatherDataset is expected to raise ValueError.
    """
    datastore = ForecastDummyDatastore(
        n_analysis_times=4, n_forecast_steps=n_forecast_steps
    )

    def _build_dataset():
        # Single construction point so both branches exercise identical
        # arguments.
        return WeatherDataset(
            datastore=datastore,
            split="train",
            ar_steps=ar_steps,
            num_past_forcing_steps=num_past_forcing_steps,
            num_future_forcing_steps=num_future_forcing_steps,
            standardize=False,
        )

    if not should_raise:
        dataset = _build_dataset()
        assert len(dataset) > 0, "Dataset should contain at least one sample"
    else:
        with pytest.raises(ValueError, match="forecast steps available"):
            _build_dataset()
Loading