Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b3c1012
Fix #311
kshirajahere Mar 2, 2026
c2d443c
Fix #311
kshirajahere Mar 2, 2026
a08eff0
Update CHANGELOG for Fix #311 and Pull #312
kshirajahere Mar 2, 2026
f3aaeca
fix: account for forecast mode forcing horizons (Closes #319)
kshirajahere Mar 3, 2026
3b1e014
fix: account for forecast mode forcing horizons (Closes #319)
kshirajahere Mar 3, 2026
7987bf4
Clarify fixes in CHANGELOG.md
kshirajahere Mar 3, 2026
1c588da
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Mar 3, 2026
ce4286e
Reduced the forecast-mode validation (#312)
kshirajahere Mar 8, 2026
10ead93
Merge remote-tracking branch 'mllam' into pr/kshirajahere/312
sadamov Mar 9, 2026
e02b422
linting
sadamov Mar 9, 2026
ed0f1a8
Fix forecast forcing horizon validation
kshirajahere Mar 17, 2026
075d92c
Resolve #312 test conflict and harden forcing handling
kshirajahere Mar 19, 2026
a1db949
Merge origin/main into fix/weatherdataset-index-bounds-311
kshirajahere Mar 19, 2026
5618bd9
Fix forecast-mode WeatherDataset length validation
kshirajahere Mar 21, 2026
4335729
Validate forecast coordinate consistency in WeatherDataset
kshirajahere Mar 21, 2026
9329f54
Merge origin/main into fix/weatherdataset-index-bounds-311
kshirajahere Mar 23, 2026
764fd7f
style: format test_datasets with black
kshirajahere Mar 24, 2026
af95a8d
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Mar 30, 2026
9d1bf17
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Apr 3, 2026
e554aa8
ci: unblock PR GPU checks
kshirajahere Apr 8, 2026
29220cf
Revert "ci: unblock PR GPU checks"
kshirajahere Apr 8, 2026
c9a0a72
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Apr 10, 2026
040ba8f
Move dataset validation out of __len__
kshirajahere Apr 10, 2026
bf161bd
Format dataset validation for pre-commit
kshirajahere Apr 10, 2026
4bacdc9
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Apr 17, 2026
b15ba68
fix: allow longer forecast forcing horizons
kshirajahere Apr 21, 2026
99fa750
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Apr 22, 2026
37947f3
Merge branch 'main' into fix/weatherdataset-index-bounds-311
kshirajahere Apr 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [unreleased](https://github.com/mllam/neural-lam/compare/v0.5.0...HEAD)

### Fixed
- Fix `WeatherDataset` boundary handling for out-of-range indexing and forecast-mode forcing horizon validation to prevent malformed samples [\#312](https://github.com/mllam/neural-lam/pull/312)

- Fix README image paths to use absolute GitHub URLs so images display correctly on PyPI [\#188](https://github.com/mllam/neural-lam/pull/188) @bk-simon

Expand Down
59 changes: 51 additions & 8 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,47 @@ def __len__(self):
UserWarning,
)

# check that there are enough forecast steps available to create
# samples given the number of autoregressive steps requested
n_forecast_steps = self.da_state.elapsed_forecast_duration.size
if n_forecast_steps < 2 + self.ar_steps:
# Check that there are enough forecast steps available for state
# slicing. This includes two initial states and `ar_steps` targets,
# potentially offset by past forcing.
required_state_steps = (
max(2, self.num_past_forcing_steps) + self.ar_steps
)
n_state_forecast_steps = (
self.da_state.elapsed_forecast_duration.size
)
if n_state_forecast_steps < required_state_steps:
raise ValueError(
"The number of forecast steps available "
f"({n_forecast_steps}) is less than the required "
f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
"creating a sample with initial and target states."
"The number of state forecast steps available "
f"({n_state_forecast_steps}) is less than the required "
f"{required_state_steps} "
f"(max(2, num_past_forcing_steps={self.num_past_forcing_steps})"
f" + ar_steps={self.ar_steps}) for creating a sample with "
"initial and target states."
)

# If forcing data is present, also validate that the complete
# forcing window can be constructed for each autoregressive target
# step without truncation.
if self.da_forcing is not None:
required_forcing_steps = (
max(2, self.num_past_forcing_steps)
+ self.ar_steps
+ self.num_future_forcing_steps
)
n_forcing_forecast_steps = (
self.da_forcing.elapsed_forecast_duration.size
)
if n_forcing_forecast_steps < required_forcing_steps:
raise ValueError(
"The number of forcing forecast steps available "
f"({n_forcing_forecast_steps}) is less than the "
f"required {required_forcing_steps} "
f"(max(2, num_past_forcing_steps={self.num_past_forcing_steps})"
f" + ar_steps={self.ar_steps} + "
f"num_future_forcing_steps={self.num_future_forcing_steps}) "
"for constructing forcing windows."
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both da_state and da_forcing are always built from the same datastore time coordinates, so their sizes are guaranteed to be equal and a separate shape check is not needed. The single forecast-mode check is still necessary, but only because the required minimum size is larger when forcing is present (+ num_future_forcing_steps), not because the arrays could ever differ in size.

The no-forcing path can remain unchanged in behaviour; the with-forcing path should skip the redundant state check and go straight to the stricter (and sufficient) forcing constraint.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed the redundant separate forcing-size check and now use the shared forecast horizon once:

  • no-forcing path keeps the original 2 + ar_steps behavior
  • with-forcing path applies the stricter minimum needed for the full forcing window

Re-ran:

  • pytest -q tests/test_datasets.py -k "dataset_length or out_of_bounds or forecast_len"
  • ruff check neural_lam/weather_dataset.py tests/test_datasets.py


return self.da_state.analysis_time.size
else:
Expand All @@ -159,6 +190,7 @@ def __len__(self):
- self.ar_steps
- max(2, self.num_past_forcing_steps)
- self.num_future_forcing_steps
+ 1
)

def _slice_state_time(self, da_state, idx, n_steps: int):
Expand Down Expand Up @@ -468,6 +500,17 @@ def __getitem__(self, idx):
the target steps.

"""
dataset_len = len(self)

# Match Python sequence semantics for negative indexing.
if idx < 0:
idx += dataset_len
if idx < 0 or idx >= dataset_len:
raise IndexError(
f"Index {idx} is out of bounds for dataset of size "
f"{dataset_len}"
)

Comment on lines +574 to +582
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was at first conflicted about this change, because it only affects the analysis type data, and only if accessed programmatically e.g. from a test. And I thought that it might rather be separate issue.

But, after consideration, I think this fits well within this PR. And knowing that there is more flexibility and robustness needed when we introduce #138 boundary datastores, this is good. Nothing to change here, just some context.

(
da_init_states,
da_target_states,
Expand Down
140 changes: 134 additions & 6 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,12 @@ def _create_graph():
@pytest.mark.parametrize(
"dataset_config",
[
{"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
{"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
{"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 5},
{"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 5},
{"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 7},
{"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 8},
{"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 2},
{"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 2},
{"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 4},
{"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 4},
{"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 6},
{"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 7},
],
)
def test_dataset_length(dataset_config):
Expand Down Expand Up @@ -259,3 +259,131 @@ def test_dataset_length(dataset_config):
# Check that we can actually get last and first sample
dataset[0]
dataset[expected_len - 1]


def test_dataset_out_of_bounds_indexing_raises():
    """Verify that indexing past either end raises ``IndexError``.

    In-range positive and negative indices must succeed, mirroring the
    semantics of standard Python sequences, while any index outside
    ``[-len, len)`` must fail loudly instead of yielding a bad sample.
    """
    datastore = DummyDatastore(n_grid_points=4, n_timesteps=10)
    dataset = WeatherDataset(
        datastore=datastore,
        split="train",
        ar_steps=2,
        num_past_forcing_steps=1,
        num_future_forcing_steps=1,
    )

    n_samples = len(dataset)

    # Valid positions, including Python-style negative indexing.
    for valid_idx in (0, n_samples - 1, -1):
        dataset[valid_idx]

    # Any index outside the valid range must raise explicitly.
    for invalid_idx in (n_samples, n_samples + 1, -n_samples - 1):
        with pytest.raises(IndexError):
            dataset[invalid_idx]


def test_forecast_len_raises_when_forcing_horizon_too_short():
    """An insufficient forcing forecast horizon must raise on ``len()``.

    With ``ar_steps=2``, ``num_past_forcing_steps=1`` and
    ``num_future_forcing_steps=2`` the full forcing window needs
    ``max(2, 1) + 2 + 2 = 6`` forecast steps, so a 5-step horizon should
    make ``len(dataset)`` fail with a ValueError that mentions the
    forcing forecast steps.
    """
    from types import SimpleNamespace

    import xarray as xr

    n_analysis, n_forecast = 2, 5
    analysis_time = np.array(
        ["2021-01-01T00:00:00", "2021-01-01T01:00:00"],
        dtype="datetime64[ns]",
    )
    elapsed = np.arange(n_forecast, dtype="timedelta64[h]").astype(
        "timedelta64[ns]"
    )

    def _make_da(feature_dim, feature_name):
        # Zero-filled forecast-style array with the minimal coordinates
        # needed by the length validation.
        return xr.DataArray(
            np.zeros((n_analysis, n_forecast, 1, 1), dtype=np.float32),
            dims=(
                "analysis_time",
                "elapsed_forecast_duration",
                "grid_index",
                feature_dim,
            ),
            coords={
                "analysis_time": analysis_time,
                "elapsed_forecast_duration": elapsed,
                "grid_index": [0],
                feature_dim: [feature_name],
            },
        )

    # Build a bare instance so no datastore I/O happens in __init__.
    dataset = WeatherDataset.__new__(WeatherDataset)
    dataset.datastore = SimpleNamespace(is_forecast=True, is_ensemble=False)
    dataset.ar_steps = 2
    dataset.num_past_forcing_steps = 1
    dataset.num_future_forcing_steps = 2
    dataset.da_state = _make_da("state_feature", "state_feat_0")
    dataset.da_forcing = _make_da("forcing_feature", "forcing_feat_0")

    with pytest.raises(ValueError, match="forcing forecast steps"):
        len(dataset)


def test_forecast_len_accepts_exact_forcing_horizon():
    """The minimal sufficient forcing forecast horizon must be accepted.

    With ``ar_steps=2``, ``num_past_forcing_steps=1`` and
    ``num_future_forcing_steps=2`` exactly ``max(2, 1) + 2 + 2 = 6``
    forecast steps are required, so a 6-step horizon should validate and
    the dataset length should equal the number of analysis times.
    """
    from types import SimpleNamespace

    import xarray as xr

    n_analysis, n_forecast = 2, 6
    analysis_time = np.array(
        ["2021-01-01T00:00:00", "2021-01-01T01:00:00"],
        dtype="datetime64[ns]",
    )
    elapsed = np.arange(n_forecast, dtype="timedelta64[h]").astype(
        "timedelta64[ns]"
    )

    def _make_da(feature_dim, feature_name):
        # Zero-filled forecast-style array with the minimal coordinates
        # needed by the length validation.
        return xr.DataArray(
            np.zeros((n_analysis, n_forecast, 1, 1), dtype=np.float32),
            dims=(
                "analysis_time",
                "elapsed_forecast_duration",
                "grid_index",
                feature_dim,
            ),
            coords={
                "analysis_time": analysis_time,
                "elapsed_forecast_duration": elapsed,
                "grid_index": [0],
                feature_dim: [feature_name],
            },
        )

    # Build a bare instance so no datastore I/O happens in __init__.
    dataset = WeatherDataset.__new__(WeatherDataset)
    dataset.datastore = SimpleNamespace(is_forecast=True, is_ensemble=False)
    dataset.ar_steps = 2
    dataset.num_past_forcing_steps = 1
    dataset.num_future_forcing_steps = 2
    dataset.da_state = _make_da("state_feature", "state_feat_0")
    dataset.da_forcing = _make_da("forcing_feature", "forcing_feat_0")

    assert len(dataset) == 2