Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions jax_cosmo/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
"angular_diameter_distance",
"growth_factor",
"growth_rate",
"luminosity_distance",
"distance_modulus",
]


Expand Down Expand Up @@ -580,3 +582,43 @@ def _growth_rate_gamma(cosmo, a):
see :cite:`2019:Euclid Preparation VII, eqn.32`
"""
return Omega_m_a(cosmo, a) ** cosmo.gamma


def luminosity_distance(cosmo, a):
"""
Compute the luminosity distance in [Mpc] for a given scale factor a
Parameters
----------
cosmo : `Cosmology'
Cosmology object

a : array_like
Scale factor

Returns
-------
d_L : ndarray, or float if input scalar
Luminosity distance corresponding to the requested scale factor
"""

return transverse_comoving_distance(cosmo, a) / a


def distance_modulus(cosmo, a):
"""
Compute the distance modulus for a given scale factor a

Parameters
----------
cosmo : `Cosmology'
Cosmology object

a : array_like
Scale factor

Returns
-------
mu : Distance modulus corresponding to scale factor a
"""

return 5 * np.log10(luminosity_distance(cosmo, a) * 1e5)
63 changes: 62 additions & 1 deletion tests/test_background.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax.numpy as jnp
import numpy as np
import pyccl as ccl
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -213,3 +212,65 @@ def test_growth_gamma():
gjax = bkgrd.growth_factor(cosmo_jax, a)

assert_allclose(gccl, gjax, rtol=1e-2)


def test_luminosity_distance():
cosmo_ccl = ccl.Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="linear",
)

cosmo_jax = Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Omega_k=0.0,
w0=-1.0,
wa=0.0,
)

# Test array of scale factors
a = np.linspace(0.01, 1.0)

dl_ccl = ccl.luminosity_distance(cosmo_ccl, a)
dl_jax = bkgrd.luminosity_distance(cosmo_jax, a) / cosmo_jax.h
assert_allclose(dl_ccl, dl_jax, rtol=0.5e-2)


def test_distance_modulus():
cosmo_ccl = ccl.Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Neff=0,
transfer_function="eisenstein_hu",
matter_power_spectrum="linear",
)

cosmo_jax = Cosmology(
Omega_c=0.3,
Omega_b=0.05,
h=0.7,
sigma8=0.8,
n_s=0.96,
Omega_k=0.0,
w0=-1.0,
wa=0.0,
)

# Test array of scale factors
a = np.linspace(0.01, 0.99)

mu_ccl = ccl.distance_modulus(cosmo_ccl, a)
mu_jax = bkgrd.distance_modulus(cosmo_jax, a) - 5*np.log10(cosmo_jax.h)
assert_allclose(mu_ccl, mu_jax, rtol=0.5e-2)