Changes from all commits (28 commits)
- b3c1012: Fix #311 (kshirajahere, Mar 2, 2026)
- c2d443c: Fix #311 (kshirajahere, Mar 2, 2026)
- a08eff0: Update CHANGELOG for Fix #311 and Pull #312 (kshirajahere, Mar 2, 2026)
- f3aaeca: fix: account for forecast mode forcing horizons (Closes #319) (kshirajahere, Mar 3, 2026)
- 3b1e014: fix: account for forecast mode forcing horizons (Closes #319) (kshirajahere, Mar 3, 2026)
- 7987bf4: Clarify fixes in CHANGELOG.md (kshirajahere, Mar 3, 2026)
- 1c588da: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Mar 3, 2026)
- ce4286e: Reduced the forecast-mode validation (#312) (kshirajahere, Mar 8, 2026)
- 10ead93: Merge remote-tracking branch 'mllam' into pr/kshirajahere/312 (sadamov, Mar 9, 2026)
- e02b422: linting (sadamov, Mar 9, 2026)
- ed0f1a8: Fix forecast forcing horizon validation (kshirajahere, Mar 17, 2026)
- 075d92c: Resolve #312 test conflict and harden forcing handling (kshirajahere, Mar 19, 2026)
- a1db949: Merge origin/main into fix/weatherdataset-index-bounds-311 (kshirajahere, Mar 19, 2026)
- 5618bd9: Fix forecast-mode WeatherDataset length validation (kshirajahere, Mar 21, 2026)
- 4335729: Validate forecast coordinate consistency in WeatherDataset (kshirajahere, Mar 21, 2026)
- 9329f54: Merge origin/main into fix/weatherdataset-index-bounds-311 (kshirajahere, Mar 23, 2026)
- 764fd7f: style: format test_datasets with black (kshirajahere, Mar 24, 2026)
- af95a8d: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Mar 30, 2026)
- 9d1bf17: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Apr 3, 2026)
- e554aa8: ci: unblock PR GPU checks (kshirajahere, Apr 8, 2026)
- 29220cf: Revert "ci: unblock PR GPU checks" (kshirajahere, Apr 8, 2026)
- c9a0a72: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Apr 10, 2026)
- 040ba8f: Move dataset validation out of __len__ (kshirajahere, Apr 10, 2026)
- bf161bd: Format dataset validation for pre-commit (kshirajahere, Apr 10, 2026)
- 4bacdc9: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Apr 17, 2026)
- b15ba68: fix: allow longer forecast forcing horizons (kshirajahere, Apr 21, 2026)
- 99fa750: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Apr 22, 2026)
- 37947f3: Merge branch 'main' into fix/weatherdataset-index-bounds-311 (kshirajahere, Apr 27, 2026)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -54,6 +54,7 @@ This release introduces new features including GIF animation support, wandb run
 - Infer spatial coordinate names for MDPDatastore (rather than assuming names `x` and `y`), allows for e.g. lat/lon regular grids [\#169](https://github.com/mllam/neural-lam/pull/169) @leifdenby

 ### Fixed
+- Fix `WeatherDataset` boundary handling for out-of-range indexing and forecast-mode forcing horizon validation to prevent malformed samples [\#312](https://github.com/mllam/neural-lam/pull/312)

 - Fix validation crash in `plot_error_map` and resolve DDP NCCL initialization error on single-device setups [\#193](https://github.com/mllam/neural-lam/pull/193) @AdityaKumarSethia
176 changes: 127 additions & 49 deletions neural_lam/weather_dataset.py
@@ -85,19 +85,6 @@ def __init__(
                 stacklevel=2,
             )

-        # check that with the provided data-arrays and ar_steps that we have a
-        # non-zero amount of samples
-        if self.__len__() <= 0 and self.da_state is not None:
-            raise ValueError(
-                "The provided datastore only provides "
-                f"{len(self.da_state.time)} total time steps, which is too few "
-                "to create a single sample for the WeatherDataset "
-                f"configuration used in the `{split}` split. You could try "
-                "either reducing the number of autoregressive steps "
-                "(`ar_steps`) and/or the forcing window size "
-                "(`num_past_forcing_steps` and `num_future_forcing_steps`)"
-            )
-
         # Check the dimensions and their ordering
         parts = dict(state=self.da_state)
         if self.da_forcing is not None:
@@ -116,6 +103,27 @@ def __init__(
                 "transpose the data in `BaseDatastore.get_dataarray`?"
             )

+        self._validate_dataset_configuration()
+        self._dataset_len = self._compute_dataset_len()
+
+        # check that with the provided data-arrays and ar_steps that we have a
+        # non-zero amount of samples
+        if self._dataset_len <= 0 and self.da_state is not None:
+            time_dim = (
+                self.da_state.analysis_time
+                if self.datastore.is_forecast
+                else self.da_state.time
+            )
+            raise ValueError(
+                "The provided datastore only provides "
+                f"{len(time_dim)} total time steps, which is too few "
+                "to create a single sample for the WeatherDataset "
+                f"configuration used in the `{split}` split. You could try "
+                "either reducing the number of autoregressive steps "
+                "(`ar_steps`) and/or the forcing window size "
+                "(`num_past_forcing_steps` and `num_future_forcing_steps`)"
+            )
+
         # Set up for standardization
         # TODO: This will become part of ar_model.py soon!
         self.standardize = standardize
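The hunk above replaces the old pattern of calling `self.__len__()` during construction (which mixed validation into the length computation) with an explicit fail-fast step: validate once, cache the result in `self._dataset_len`, and keep `__len__` trivial. A minimal sketch of that pattern outside neural-lam, with illustrative names (`CachedLenDataset`, `_compute_len`) that are not from the codebase:

```python
import numpy as np
from torch.utils.data import Dataset


class CachedLenDataset(Dataset):
    """Sketch of the validate-once, cache-length pattern (illustrative)."""

    def __init__(self, times: np.ndarray, ar_steps: int):
        self.times = times
        self.ar_steps = ar_steps
        self._validate()  # fail fast with a descriptive error
        self._len = self._compute_len()  # cache so __len__ stays cheap

    def _validate(self) -> None:
        if self.ar_steps < 1:
            raise ValueError("ar_steps must be >= 1")

    def _compute_len(self) -> int:
        # each sample needs 2 initial states plus ar_steps target steps
        return max(0, len(self.times) - (2 + self.ar_steps) + 1)

    def __len__(self) -> int:
        return self._len
```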
@@ -150,56 +158,115 @@ def __init__(
         else:
             self.forcing_std_safe = None

-    def _compute_std_safe(self, std: xr.DataArray, feature: str):
-        eps = np.finfo(std.dtype).eps
-        if bool((std <= eps).any()):
-            logger.warning(
-                f"Some {feature} features have near-zero std and will be "
-                "standardized using machine epsilon to avoid NaN."
+    def _validate_dataset_configuration(self):
+        """Validate dataset structure once during initialization."""
+        if not self.datastore.is_forecast:
+            return
+
+        required_state_forecast_steps = (
+            max(2, self.num_past_forcing_steps) + self.ar_steps
+        )
+        n_state_forecast_steps = self.da_state.elapsed_forecast_duration.size
+        if n_state_forecast_steps < required_state_forecast_steps:
+            raise ValueError(
+                "The number of forecast steps available "
+                f"({n_state_forecast_steps}) is less than the required "
+                f"{required_state_forecast_steps} "
+                f"(max(2, num_past_forcing_steps="
+                f"{self.num_past_forcing_steps}) + ar_steps="
+                f"{self.ar_steps}) for creating a sample with initial "
+                "and target states."
             )
-        return std.where(std > eps, other=eps)

-    def __len__(self):
-        if self.datastore.is_forecast:
-            # for now we simply create a single sample for each analysis time
-            # and then take the first (2 + ar_steps) forecast times.
-            # If the datastore returns an ensemble of state realisations and
-            # `load_single_member=False`, each ensemble member is exposed as an
-            # independent sample by scaling the base dataset length below.
-
-            # check that there are enough forecast steps available to create
-            # samples given the number of autoregressive steps requested
-            n_forecast_steps = self.da_state.elapsed_forecast_duration.size
-            if n_forecast_steps < 2 + self.ar_steps:
-                raise ValueError(
-                    "The number of forecast steps available "
-                    f"({n_forecast_steps}) is less than the required "
-                    f"2+ar_steps (2+{self.ar_steps}={2 + self.ar_steps}) for "
-                    "creating a sample with initial and target states."
-                )
+        if self.da_forcing is None:
+            return
+
+        if not np.array_equal(
+            self.da_state.analysis_time.values,
+            self.da_forcing.analysis_time.values,
+        ):
+            raise ValueError(
+                "State and forcing analysis times must match for "
+                "forecast-mode datasets."
+            )
+
+        if not np.array_equal(
+            self.da_state.elapsed_forecast_duration.values[
+                :required_state_forecast_steps
+            ],
+            self.da_forcing.elapsed_forecast_duration.values[
+                :required_state_forecast_steps
+            ],
+        ):
+            raise ValueError(
+                "State and forcing forecast lead times must match across "
+                "the state forecast horizon used for target alignment in "
+                "forecast-mode datasets."
+            )
+
+        n_forcing_forecast_steps = (
+            self.da_forcing.elapsed_forecast_duration.size
+        )
+        required_forcing_forecast_steps = (
+            max(2, self.num_past_forcing_steps)
+            + self.ar_steps
+            + self.num_future_forcing_steps
+        )
+        if n_forcing_forecast_steps < required_forcing_forecast_steps:
+            raise ValueError(
+                "The number of forcing forecast steps available "
+                f"({n_forcing_forecast_steps}) is less than the required "
+                f"{required_forcing_forecast_steps} "
+                f"(max(2, num_past_forcing_steps="
+                f"{self.num_past_forcing_steps})"
+                f" + ar_steps={self.ar_steps} + "
+                f"num_future_forcing_steps="
+                f"{self.num_future_forcing_steps}) "
+                "for constructing forcing windows."
+            )
+
+    def _compute_dataset_len(self):
+        """Compute dataset length without running structural validation."""
+        if self.datastore.is_forecast:
+            base_len = self.da_state.analysis_time.size
+        else:
-            # Calculate the number of samples in the dataset n_samples = total
-            # time steps - (autoregressive steps + past forcing + future
-            # forcing)
-            #:
-            # Where:
-            # - total time steps: len(self.da_state.time)
-            # - autoregressive steps: self.ar_steps
-            # - past forcing: max(2, self.num_past_forcing_steps) (at least 2
-            #   time steps are required for the initial state)
-            # - future forcing: self.num_future_forcing_steps
-            base_len = (
+            n_state_samples = (
                 len(self.da_state.time)
                 - self.ar_steps
                 - max(2, self.num_past_forcing_steps)
                 - self.num_future_forcing_steps
                 + 1
             )
+            if self.da_forcing is None:
+                base_len = max(0, n_state_samples)
+            else:
+                n_forcing_samples = (
+                    len(self.da_forcing.time)
+                    - self.ar_steps
+                    - max(2, self.num_past_forcing_steps)
+                    - self.num_future_forcing_steps
+                    + 1
+                )
+                base_len = max(0, min(n_state_samples, n_forcing_samples))

         if self.datastore.is_ensemble and not self.load_single_member:
             return base_len * self.da_state.ensemble_member.size
         return base_len
+
+    def _compute_std_safe(self, std: xr.DataArray, feature: str):
+        eps = np.finfo(std.dtype).eps
+        if bool((std <= eps).any()):
+            logger.warning(
+                f"Some {feature} features have near-zero std and will be "
+                "standardized using machine epsilon to avoid NaN."
+            )
+        return std.where(std > eps, other=eps)
+
+    def __len__(self):
+        if hasattr(self, "_dataset_len"):
+            return self._dataset_len
+        return self._compute_dataset_len()

     def _slice_state_time(self, da_state, idx, n_steps: int):
         """
         Produce a time slice of the given dataarray `da_state` (state) starting
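To sanity-check the arithmetic in `_compute_dataset_len` and the forecast-horizon requirements above, here is a worked example with made-up numbers (none of these values come from the PR itself):

```python
# Hypothetical configuration, chosen only to exercise the formulas.
n_time = 10  # len(da_state.time), analysis mode
ar_steps = 4
num_past_forcing_steps = 1
num_future_forcing_steps = 1

# Analysis mode: samples = T - ar_steps - max(2, past) - future + 1
n_state_samples = (
    n_time
    - ar_steps
    - max(2, num_past_forcing_steps)
    - num_future_forcing_steps
    + 1
)
assert n_state_samples == 4  # 10 - 4 - 2 - 1 + 1

# Forecast mode: lead times required of the state per analysis time
required_state = max(2, num_past_forcing_steps) + ar_steps
assert required_state == 6

# ...and of the forcing, which additionally needs the future window
required_forcing = required_state + num_future_forcing_steps
assert required_forcing == 7
```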
@@ -502,6 +569,17 @@ def __getitem__(self, idx):
         the target steps.

         """
+        dataset_len = self._dataset_len
+
+        # Match Python sequence semantics for negative indexing.
+        if idx < 0:
+            idx += dataset_len
+        if idx < 0 or idx >= dataset_len:
+            raise IndexError(
+                f"Index {idx} is out of bounds for dataset of size "
+                f"{dataset_len}"
+            )
+
Comment on lines +574 to +582
Collaborator:
I was at first conflicted about this change, because it only affects the analysis type data, and only if accessed programmatically e.g. from a test. And I thought that it might rather be a separate issue.

But, after consideration, I think this fits well within this PR. And knowing that there is more flexibility and robustness needed when we introduce #138 boundary datastores, this is good. Nothing to change here, just some context.

         (
             da_init_states,
             da_target_states,
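The `__getitem__` guard discussed in the comment above mirrors Python's sequence protocol: a negative index is offset by the length once, and anything still out of range raises `IndexError` instead of silently producing a malformed sample. A small usage sketch of the expected behaviour, assuming `ds` is any `WeatherDataset` instance (construction elided):

```python
# Illustrative usage of the new bounds check; `ds` construction elided.
n = len(ds)

first = ds[0]   # in range
last = ds[-1]   # normalized to ds[n - 1], as for a Python list

for bad_idx in (n, -n - 1):
    try:
        ds[bad_idx]
    except IndexError:
        pass  # raised explicitly by the new guard
```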