Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import jax_cosmo.background as bkgrd
import jax_cosmo.constants as const
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a
Expand Down Expand Up @@ -54,7 +53,7 @@ def find_index(a, b):


def angular_cl(
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit
cosmo, ell, probes, transfer_fn=None, nonlinear_fn=None
):
"""
Computes angular Cls for the provided probes
Expand All @@ -66,6 +65,13 @@ def angular_cl(

cls: [ell, ncls]
"""
# Use transfer_fn and nonlinear_fn if provided, else fallback to what is
# defined in cosmo
if transfer_fn is None:
transfer_fn = cosmo.transfer_fn
if nonlinear_fn is None:
nonlinear_fn = cosmo.nonlinear_fn

# Retrieve the maximum redshift probed
zmax = max([p.zmax for p in probes])

Expand Down Expand Up @@ -172,8 +178,8 @@ def gaussian_cl_covariance_and_mean(
cosmo,
ell,
probes,
transfer_fn=tklib.Eisenstein_Hu,
nonlinear_fn=power.halofit,
transfer_fn=None,
nonlinear_fn=None,
f_sky=0.25,
sparse=False,
):
Expand All @@ -186,6 +192,13 @@ def gaussian_cl_covariance_and_mean(

return_cls: (returns signal + noise cl, covariance)
"""
# Use transfer_fn and nonlinear_fn if provided, else fallback to what is
# defined in cosmo
if transfer_fn is None:
transfer_fn = cosmo.transfer_fn
if nonlinear_fn is None:
nonlinear_fn = cosmo.nonlinear_fn

ell = np.atleast_1d(ell)
n_ell = len(ell)

Expand Down
33 changes: 30 additions & 3 deletions jax_cosmo/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from jax.tree_util import register_pytree_node_class

import jax_cosmo.constants as const
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a

Expand All @@ -15,7 +17,8 @@

@register_pytree_node_class
class Cosmology:
def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None):
def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None,
transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit):
"""
Cosmology object, stores primary and derived cosmological parameters.

Expand All @@ -37,8 +40,12 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None
First order term of dark energy equation
wa, float
Second order term of dark energy equation of state
gamma: float
gamma: float, optional
Index of the growth rate (optional)
transfer_fn: transfer_fn(cosmo, k, **kwargs), optional
Transfer function.
nonlinear_fn: nonlinear_fn(cosmo, k, **kwargs), optional
Non-linear matter power spectrum function.

Notes:
------
Expand All @@ -63,6 +70,8 @@ def __init__(self, Omega_c, Omega_b, h, n_s, sigma8, Omega_k, w0, wa, gamma=None
# Secondary optional parameters
self._gamma = gamma
self._flags["gamma_growth"] = gamma is not None
self._flags["config"] = {"transfer_fn" : transfer_fn,
"nonlinear_fn" : nonlinear_fn}

# Create a workspace where functions can store some precomputed
# results
Expand Down Expand Up @@ -143,6 +152,7 @@ def tree_unflatten(cls, aux_data, children):
w0=w0,
wa=wa,
gamma=gamma,
**aux_data["config"]
)

# Cosmological parameters, base and derived
Expand Down Expand Up @@ -181,7 +191,7 @@ def k(self):
return k

@property
def sqrtk(self):
def _sqrtk(self):
return np.sqrt(np.abs(self._Omega_k))

@property
Expand Down Expand Up @@ -211,3 +221,20 @@ def sigma8(self):
@property
def gamma(self):
return self._gamma

# Options
@property
def transfer_fn(self):
return self._flags["config"]["transfer_fn"]

@transfer_fn.setter
def transfer_fn(self, value):
self._flags["config"]["transfer_fn"] = value

@property
def nonlinear_fn(self):
return self._flags["config"]["nonlinear_fn"]

@nonlinear_fn.setter
def nonlinear_fn(self, value):
self._flags["config"]["nonlinear_fn"] = value
5 changes: 5 additions & 0 deletions jax_cosmo/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from functools import partial

from jax_cosmo.core import Cosmology
from jax_cosmo.power import halofit

# To add new cosmologies, we just set the parameters to some default values using
# partial
Expand All @@ -22,3 +23,7 @@
w0=-1.0,
wa=0.0,
)

# Shortcuts for the different halofit implementations
halofit_smith2003 = partial(halofit, prescription="smith2003")
halofit_takahashi2012 = partial(halofit, prescription="takahashi2012")
22 changes: 17 additions & 5 deletions jax_cosmo/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import jax_cosmo.background as bkgrd
import jax_cosmo.constants as const
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import romb
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.interpolate import interp
Expand All @@ -23,7 +22,7 @@ def primordial_matter_power(cosmo, k):
return k ** cosmo.n_s


def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs):
def linear_matter_power(cosmo, k, a=1.0, transfer_fn=None, **kwargs):
r""" Computes the linear matter power spectrum.

Parameters
Expand All @@ -34,8 +33,8 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar
a: array_like, optional
Scale factor (def: 1.0)

transfer_fn: transfer_fn(cosmo, k, **kwargs)
Transfer function
transfer_fn: transfer_fn(cosmo, k, **kwargs), optional
Transfer function, if None uses cosmo.transfer_fn

Returns
-------
Expand All @@ -44,6 +43,10 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar
and scale factor.

"""
# Use transfer_fn if provided, else fallback to what is defined in cosmo
if transfer_fn is None:
transfer_fn = cosmo.transfer_fn

k = np.atleast_1d(k)
a = np.atleast_1d(a)
g = bkgrd.growth_factor(cosmo, a)
Expand Down Expand Up @@ -154,6 +157,9 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
a: array_like, optional
Scale factor (def: 1.0)

transfer_fn: transfer_fn(cosmo, k, **kwargs), optional
Transfer function, if None uses cosmo.transfer_fn

prescription: str, optional
Either 'smith2003' or 'takahashi2012'

Expand Down Expand Up @@ -269,10 +275,16 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):


def nonlinear_matter_power(
cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=halofit
cosmo, k, a=1.0, transfer_fn=None, nonlinear_fn=None
):
""" Computes the non-linear matter power spectrum.

This function is just a wrapper over several nonlinear power spectra.
"""
# Use transfer_fn and nonlinear_fn if provided, else fallback to what is
# defined in cosmo
if transfer_fn is None:
transfer_fn = cosmo.transfer_fn
if nonlinear_fn is None:
nonlinear_fn = cosmo.nonlinear_fn
return nonlinear_fn(cosmo, k, a, transfer_fn=transfer_fn)