From 63e82ef91018bf2ba52e3ed17675f052af83252a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Thu, 4 Jun 2026 16:31:33 +0200 Subject: [PATCH] feat(groupby): fast path for multi-key groupby([names]), flat + non-breaking `groupby(["a","b"]).sum()` previously dropped to the slow xarray fallback. Resolve a list of coordinate names (1-D, same dim) to a value frame so it rides the existing reindex fast path, then unstack the stacked result back into flat per-name dims -- byte-identical to the fallback, sparse fill cells included. The DataFrame grouper is untouched and stays compact (stacked MultiIndex over observed combinations only), so this is non-breaking. Flat dims are a dense Cartesian grid, so a sparse key crossing materialises mostly-fill cells. Warn (pointing at the DataFrame grouper) when the grid is much larger than the observed combinations (more than twice as many cells, with over 10,000 of them fill); the check reads the MultiIndex levels after `remove_unused_levels()`, so it is O(observed) and fires before unstack allocates. See #753; sparse-representation follow-ups tracked against #740. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 35 +++++++++++++++++ test/test_linear_expression.py | 68 ++++++++++++++++++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/linopy/expressions.py b/linopy/expressions.py index 07c2217f..09627746 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -252,6 +252,21 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression: """ group = _resolve_group(self.group, self.data) + # A list of coord names rides the fast path as a value frame, then + # unstacks back to flat dims -- fast and flat, like the fallback (#753). 
+ flatten_multikey = False + if ( + not use_fallback + and isinstance(group, (list, tuple)) + and len(group) > 1 + and all(isinstance(g, str) and g in self.data.coords for g in group) + ): + coord_dims = {self.data[g].dims for g in group} + if len(coord_dims) == 1 and len(next(iter(coord_dims))) == 1: + names = list(group) + group = self.data[names].to_dataframe()[names] + flatten_multikey = True + non_fallback_types = (pd.Series, pd.DataFrame, xr.DataArray) if isinstance(group, non_fallback_types) and not use_fallback: if isinstance(group, pd.DataFrame): @@ -297,6 +312,26 @@ def sum(self, use_fallback: bool = False, **kwargs: Any) -> LinearExpression: ds = ds.assign_coords(new_coords) ds = ds.rename({GROUP_DIM: final_group_name}) + if flatten_multikey: + # warn before unstack allocates the grid when the keys are + # sparse enough that most cells would be fill (O(observed), not O(N)) + mi = ds.indexes[final_group_name].remove_unused_levels() + observed = len(mi) + grid = int(np.prod([len(level) for level in mi.levels])) + if grid > 2 * observed and grid - observed > 10_000: + warn( + f"Grouping a LinearExpression by {names} produces a dense " + f"{grid:,}-cell grid, but only {observed:,} of those " + f"combinations occur -- the {grid - observed:,} absent ones " + f"are materialised as fill values. 
Group by a `pd.DataFrame` " + f"of these keys instead to keep the result compact over only " + f"the observed combinations.", + UserWarning, + stacklevel=2, + ) + ds = ds.unstack( + final_group_name, fill_value=LinearExpression._fill_value + ) return LinearExpression(ds, self.model) def func(ds: Dataset) -> Dataset: diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 3e3bff74..1282f628 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,6 +7,7 @@ from __future__ import annotations +import warnings from typing import Any import numpy as np @@ -1541,6 +1542,73 @@ def test_multi_key_dataarrays_unsupported( with pytest.raises(TypeError, match="unhashable"): expr.groupby([period, season]).sum() + @staticmethod + def _sparse_expr(period_vals: list, season_vals: list) -> LinearExpression: + n = len(period_vals) + s = pd.RangeIndex(n, name="s") + m = Model() + x = m.add_variables(coords=[s], name="x") + return (1.0 * x).assign_coords( + period=xr.DataArray(period_vals, dims="s", coords={"s": s}, name="period"), + season=xr.DataArray(season_vals, dims="s", coords={"s": s}, name="season"), + ) + + @pytest.mark.parametrize("spelling", [list, tuple], ids=["list", "tuple"]) + def test_multikey_fast_path_matches_fallback(self, spelling: type) -> None: + # GH #753: the fast path must equal the slow fallback, sparse cells too. 
+ expr = self._sparse_expr([2020, 2020, 2030, 2030, 2030], list("wswws")) + group = spelling(["period", "season"]) + + fast = expr.groupby(group).sum() + slow = expr.groupby(group).sum(use_fallback=True) + + assert_linequal(fast, slow) + + def test_multikey_fast_path_is_flat_not_stacked(self) -> None: + # built via a stacked index internally, but returns flat separate dims + expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wsws")) + + grouped = expr.groupby(["period", "season"]).sum() + + assert {"period", "season"} <= set(grouped.dims) + assert "group" not in grouped.dims + assert not isinstance(grouped.data.indexes.get("period"), pd.MultiIndex) + + def test_multikey_sparse_combination_is_filled(self) -> None: + # (2020, "s") never occurs -> empty term in the flat grid + expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wwws")) + + grouped = expr.groupby(["period", "season"]).sum() + + cell = grouped.sel(period=2020, season="s") + assert (cell.vars == -1).all() + assert cell.coeffs.isnull().all() + + def test_multikey_dataframe_grouper_stays_compact(self) -> None: + # the DataFrame grouper keeps the stacked observed-only group dim + expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wwws")) + df = expr.data[["period", "season"]].to_dataframe()[["period", "season"]] + + grouped = expr.groupby(df).sum() + + assert "group" in grouped.dims + assert isinstance(grouped.data.indexes["group"], pd.MultiIndex) + assert grouped.sizes["group"] == 3 # observed, not the 2x2=4 grid + + def test_multikey_blowup_warns_when_sparse(self) -> None: + # 200 observed combos, 200x200 grid -> nudge toward the DataFrame grouper + expr = self._sparse_expr(list(range(200)), list(range(200))) + + with pytest.warns(UserWarning, match="dense .* grid"): + expr.groupby(["period", "season"]).sum() + + def test_multikey_no_warning_when_dense(self) -> None: + expr = self._sparse_expr([2020, 2020, 2030, 2030], list("wsws")) + + with warnings.catch_warnings(): + 
warnings.simplefilter("error") + expr.groupby(["period", "season"]).sum() + @pytest.mark.parametrize("use_fallback", [True, False]) @pytest.mark.parametrize( "level, values, vars_",