Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions jax_cosmo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,19 @@

@register_pytree_node_class
class Cosmology:
def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None):
def __init__(
self,
Omega_c,
Omega_b,
h,
n_s,
sigma8,
Omega_k,
w0,
wa,
Omega_nu=0.0,
gamma=None,
):
r"""
Cosmology object, stores primary and derived cosmological parameters.

Expand All @@ -34,6 +46,8 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None
Second order term of dark energy equation of state
gamma: float
Index of the growth rate (optional)
Omega_nu, float
Neutrino density fraction (added support for massive neutrinos)

Notes:
------
Expand All @@ -52,6 +66,7 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None
self._Omega_k = Omega_k
self._w0 = w0
self._wa = wa
self._Omega_nu = Omega_nu # Added Neutrino mass support

self._flags = {}

Expand Down Expand Up @@ -89,6 +104,9 @@ def __str__(self):
+ " \n"
+ " sigma8: "
+ str(self.sigma8)
+ " \n"
+ " Omega_nu: "
+ str(self.Omega_nu)
)

def __repr__(self):
Expand All @@ -105,6 +123,7 @@ def tree_flatten(self):
self._Omega_k,
self._w0,
self._wa,
self._Omega_nu,
)

if self._flags["gamma_growth"]:
Expand All @@ -118,8 +137,8 @@ def tree_flatten(self):
@classmethod
def tree_unflatten(cls, aux_data, children):
# Retrieve base parameters
Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa = children[:8]
children = list(children[8:]).reverse()
Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, Omega_nu = children[:9]
children = list(children[9:]).reverse()

# We extract the remaining parameters in reverse order from how they
# were inserted
Expand All @@ -137,6 +156,7 @@ def tree_unflatten(cls, aux_data, children):
Omega_k=Omega_k,
w0=w0,
wa=wa,
Omega_nu=Omega_nu,
gamma=gamma,
)

Expand All @@ -153,9 +173,14 @@ def Omega_b(self):
def Omega_c(self):
return self._Omega_c

@property
def Omega_nu(self):
return self._Omega_nu

@property
def Omega_m(self):
return self._Omega_b + self._Omega_c
# FIX: Include neutrinos in total matter so it sums to 0.3 correctly
return self._Omega_b + self._Omega_c + self._Omega_nu

@property
def Omega_de(self):
Expand Down
16 changes: 16 additions & 0 deletions jax_cosmo/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This module defines a few default cosmologies
from functools import partial

import jax_cosmo as jc # Assuming jc is imported as in your snippet
from jax_cosmo.core import Cosmology

# To add new cosmologies, we just set the parameters to some default values using
Expand All @@ -12,9 +13,24 @@
Omega_c=0.2589,
Omega_b=0.04860,
Omega_k=0.0,
Omega_nu=0.0014, # Added: Standard minimal mass neutrinos
h=0.6774,
n_s=0.9667,
sigma8=0.8159,
w0=-1.0,
wa=0.0,
)
# TO BE CHECKED
# Planck 2018 paper VI Table 1 first column (best fit)
Planck18 = partial(
Cosmology,
Omega_c=0.2607,
Omega_b=0.0490,
Omega_k=0.0,
Omega_nu=0.0014, # Added: Standard minimal mass neutrinos
h=0.6766,
n_s=0.9665,
sigma8=0.8102,
w0=-1.0,
wa=0.0,
)
Loading