diff --git a/jax_cosmo/core.py b/jax_cosmo/core.py index b336d32..dbc3cc2 100644 --- a/jax_cosmo/core.py +++ b/jax_cosmo/core.py @@ -10,7 +10,19 @@ @register_pytree_node_class class Cosmology: - def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None): + def __init__( + self, + Omega_c, + Omega_b, + h, + n_s, + sigma8, + Omega_k, + w0, + wa, + Omega_nu=0.0, + gamma=None, + ): r""" Cosmology object, stores primary and derived cosmological parameters. @@ -34,6 +46,8 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None Second order term of dark energy equation of state gamma: float Index of the growth rate (optional) + Omega_nu, float + Neutrino density fraction (added support for massive neutrinos) Notes: ------ @@ -52,6 +66,7 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None self._Omega_k = Omega_k self._w0 = w0 self._wa = wa + self._Omega_nu = Omega_nu # Added Neutrino mass support self._flags = {} @@ -89,6 +104,9 @@ def __str__(self): + " \n" + " sigma8: " + str(self.sigma8) + + " \n" + + " Omega_nu: " + + str(self.Omega_nu) ) def __repr__(self): @@ -105,6 +123,7 @@ def tree_flatten(self): self._Omega_k, self._w0, self._wa, + self._Omega_nu, ) if self._flags["gamma_growth"]: @@ -118,8 +137,8 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): # Retrieve base parameters - Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa = children[:8] - children = list(children[8:]).reverse() + Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, Omega_nu = children[:9] + children = list(children[9:]).reverse() # We extract the remaining parameters in reverse order from how they # were inserted @@ -137,6 +156,7 @@ def tree_unflatten(cls, aux_data, children): Omega_k=Omega_k, w0=w0, wa=wa, + Omega_nu=Omega_nu, gamma=gamma, ) @@ -153,9 +173,14 @@ def Omega_b(self): def Omega_c(self): return self._Omega_c + @property + def Omega_nu(self): + return self._Omega_nu + @property def Omega_m(self): - return self._Omega_b + self._Omega_c + # FIX: Include neutrinos in total matter so it sums to 0.3 correctly + return self._Omega_b + self._Omega_c + self._Omega_nu @property def Omega_de(self): diff --git a/jax_cosmo/parameters.py b/jax_cosmo/parameters.py index 8ff7651..278a792 100644 --- a/jax_cosmo/parameters.py +++ b/jax_cosmo/parameters.py @@ -1,6 +1,7 @@ # This module defines a few default cosmologies from functools import partial +import jax_cosmo as jc # Assuming jc is imported as in your snippet from jax_cosmo.core import Cosmology # To add new cosmologies, we just set the parameters to some default values using @@ -12,9 +13,24 @@ Omega_c=0.2589, Omega_b=0.04860, Omega_k=0.0, + Omega_nu=0.0014, # Added: Standard minimal mass neutrinos h=0.6774, n_s=0.9667, sigma8=0.8159, w0=-1.0, wa=0.0, ) +# TO BE CHECKED +# Planck 2018 paper VI Table 1 first column (best fit) +Planck18 = partial( + Cosmology, + Omega_c=0.2607, + Omega_b=0.0490, + Omega_k=0.0, + Omega_nu=0.0014, # Added: Standard minimal mass neutrinos + h=0.6766, + n_s=0.9665, + sigma8=0.8102, + w0=-1.0, + wa=0.0, +)