diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index 4ff49d6..badf7a3 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -120,7 +120,8 @@ def tree_flatten(self): def tree_unflatten(cls, aux_data, children): # Retrieve base parameters Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa = children[:8] - children = list(children[8:]).reverse() + children = list(children[8:]) + children.reverse() # We extract the remaining parameters in reverse order from how they # were inserted