Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,21 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression:
"""
group = _resolve_group(self.group, self.data)

# A list of coord names rides the fast path as a value frame, then
# unstacks back to flat dims -- fast and flat, like the fallback (#753).
flatten_multikey = False
if (
not use_fallback
and isinstance(group, (list, tuple))
and len(group) > 1
and all(isinstance(g, str) and g in self.data.coords for g in group)
):
coord_dims = {self.data[g].dims for g in group}
if len(coord_dims) == 1 and len(next(iter(coord_dims))) == 1:
names = list(group)
group = self.data[names].to_dataframe()[names]
flatten_multikey = True

non_fallback_types = (pd.Series, pd.DataFrame, xr.DataArray)
if isinstance(group, non_fallback_types) and not use_fallback:
if isinstance(group, pd.DataFrame):
Expand Down Expand Up @@ -297,6 +312,26 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression:
ds = ds.assign_coords(new_coords)

ds = ds.rename({GROUP_DIM: final_group_name})
if flatten_multikey:
# warn before unstack allocates the grid when the keys are
# sparse enough that most cells would be fill (O(observed), not O(N))
mi = ds.indexes[final_group_name].remove_unused_levels()
observed = len(mi)
grid = int(np.prod([len(level) for level in mi.levels]))
if grid > 2 * observed and grid - observed > 10_000:
warn(
f"Grouping a LinearExpression by {names} produces a dense "
f"{grid:,}-cell grid, but only {observed:,} of those "
f"combinations occur -- the {grid - observed:,} absent ones "
f"are materialised as fill values. Group by a `pd.DataFrame` "
f"of these keys instead to keep the result compact over only "
f"the observed combinations.",
UserWarning,
stacklevel=2,
)
ds = ds.unstack(
final_group_name, fill_value=LinearExpression._fill_value
)
return LinearExpression(ds, self.model)

def func(ds: Dataset) -> Dataset:
Expand Down
68 changes: 68 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from __future__ import annotations

import warnings
from typing import Any

import numpy as np
Expand Down Expand Up @@ -1541,6 +1542,73 @@ def test_multi_key_dataarrays_unsupported(
with pytest.raises(TypeError, match="unhashable"):
expr.groupby([period, season]).sum()

@staticmethod
def _sparse_expr(period_vals: list, season_vals: list) -> LinearExpression:
n = len(period_vals)
s = pd.RangeIndex(n, name="s")
m = Model()
x = m.add_variables(coords=[s], name="x")
return (1.0 * x).assign_coords(
period=xr.DataArray(period_vals, dims="s", coords={"s": s}, name="period"),
season=xr.DataArray(season_vals, dims="s", coords={"s": s}, name="season"),
)

@pytest.mark.parametrize("spelling", [list, tuple], ids=["list", "tuple"])
def test_multikey_fast_path_matches_fallback(self, spelling: type) -> None:
# GH #753: the fast path must equal the slow fallback, sparse cells too.
expr = self._sparse_expr([2020, 2020, 2030, 2030, 2030], list("wswws"))
group = spelling(["period", "season"])

fast = expr.groupby(group).sum()
slow = expr.groupby(group).sum(use_fallback=True)

assert_linequal(fast, slow)

def test_multikey_fast_path_is_flat_not_stacked(self) -> None:
# built via a stacked index internally, but returns flat separate dims
expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wsws"))

grouped = expr.groupby(["period", "season"]).sum()

assert {"period", "season"} <= set(grouped.dims)
assert "group" not in grouped.dims
assert not isinstance(grouped.data.indexes.get("period"), pd.MultiIndex)

def test_multikey_sparse_combination_is_filled(self) -> None:
# (2020, "s") never occurs -> empty term in the flat grid
expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wwws"))

grouped = expr.groupby(["period", "season"]).sum()

cell = grouped.sel(period=2020, season="s")
assert (cell.vars == -1).all()
assert cell.coeffs.isnull().all()

def test_multikey_dataframe_grouper_stays_compact(self) -> None:
# the DataFrame grouper keeps the stacked observed-only group dim
expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wwws"))
df = expr.data[["period", "season"]].to_dataframe()[["period", "season"]]

grouped = expr.groupby(df).sum()

assert "group" in grouped.dims
assert isinstance(grouped.data.indexes["group"], pd.MultiIndex)
assert grouped.sizes["group"] == 3 # observed, not the 2x2=4 grid

def test_multikey_blowup_warns_when_sparse(self) -> None:
# 200 observed combos, 200x200 grid -> nudge toward the DataFrame grouper
expr = self._sparse_expr(list(range(200)), list(range(200)))

with pytest.warns(UserWarning, match="dense .* grid"):
expr.groupby(["period", "season"]).sum()

def test_multikey_no_warning_when_dense(self) -> None:
expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wsws"))

with warnings.catch_warnings():
warnings.simplefilter("error")
expr.groupby(["period", "season"]).sum()

@pytest.mark.parametrize("use_fallback", [True, False])
@pytest.mark.parametrize(
"level, values, vars_",
Expand Down
Loading