Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.interpolate import interp
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a

Expand Down Expand Up @@ -78,8 +79,14 @@ def integrand(a):
# pk should have shape [na]
pk = power.nonlinear_matter_power(cosmo, k, a, transfer_fn, nonlinear_fn)

# RSD inversion

a_1 = np.clip(bkgrd.a_of_chi(cosmo, (ell + 1.5) / k), 0.00001)

# Compute the kernels for all probes
kernels = np.vstack([p.kernel(cosmo, a2z(a), ell) for p in probes])
kernels = np.vstack(
[p.kernel(cosmo, a2z(a), ell, a2z(a_1)) for p in probes]
)

# Define an ordering for the blocks of the signal vector
cl_index = np.array(_get_cl_ordering(probes))
Expand Down
15 changes: 15 additions & 0 deletions jax_cosmo/bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ def __call__(self, cosmo, z):
return b * np.ones_like(z)


@register_pytree_node_class
class test_mag_bias(container):
"""
Class representing a more complex bias for magnitude biasing term, just for testing?

Parameters:
-----------
b: redshift independent bias value
"""

def __call__(self, cosmo, z):
b = self.params[0]
return 2.0 / 5.0 + b * np.sqrt(1.0 + z)


@register_pytree_node_class
class inverse_growth_linear_bias(container):
"""
Expand Down
92 changes: 85 additions & 7 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,39 @@ def integrand_single(z_prime):
return constant_factor * ell_factor * radial_kernel


@jit
def mag_kernel(cosmo, pzs, z, ell, s):
"""
Returns a magnification kernel

Needs magnification bias function
s = "logarithmic derivative of the number of sources with magnitude limit", a function valid for all z in z_prime

"""
z = np.atleast_1d(z)
zmax = max([pz.zmax for pz in pzs])
# Retrieve comoving distance corresponding to z
chi = bkgrd.radial_comoving_distance(cosmo, z2a(z))

@vmap
def integrand(z_prime):
chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime))
# Stack the dndz of all redshift bins
dndz = np.stack([pz(z_prime) for pz in pzs], axis=0)

mag_lim = (2.0 - 5.0 * s(cosmo, z_prime)) / 2.0

return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) * mag_lim

# Computes the radial weak lensing kernel
radial_kernel = np.squeeze(simps(integrand, z, zmax, 256) * (1.0 + z) * chi)
# Constant term (maybe one too many 2.0?)
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c / 2.0
# Ell dependent factor
ell_factor = ell * (ell + 1)
return constant_factor * ell_factor * radial_kernel


@jit
def density_kernel(cosmo, pzs, bias, z, ell):
"""
Expand Down Expand Up @@ -131,6 +164,44 @@ def nla_kernel(cosmo, pzs, bias, z, ell):
return constant_factor * ell_factor * radial_kernel


@jit
def rsd_kernel(cosmo, pzs, z, ell, z1):
"""
Computes the RSD kernel
"""
print(z, z1)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)

# Normalization,
constant_factor = 1.0
# Ell dependent factor
ell_factor1 = (1 + 8 * ell) / ((2 * ell + 1) ** 2.0)
# stack the dndz of all redshift bins
dndz = np.stack([pz(z) for pz in pzs], axis=0)
radial_kernel1 = (
dndz
* bkgrd.growth_rate(cosmo, z2a(z))
/ bkgrd.growth_factor(cosmo, z2a(z))
* bkgrd.H(cosmo, z2a(z))
)

# Ell dependent factor
ell_factor2 = (4) / (2 * ell + 3) * np.sqrt((2 * ell + 1) / (2 * ell + 3))
# stack the dndz of all redshift bins
dndz = np.stack([pz(z1) for pz in pzs], axis=0)
radial_kernel2 = (
dndz
* bkgrd.growth_rate(cosmo, z2a(z1))
/ bkgrd.growth_factor(cosmo, z2a(z1))
* bkgrd.H(cosmo, z2a(z1))
)

return constant_factor * (
ell_factor1 * radial_kernel1 - ell_factor2 * radial_kernel2
)


@register_pytree_node_class
class WeakLensing(container):
"""
Expand Down Expand Up @@ -187,7 +258,7 @@ def zmax(self):
pzs = self.params[0]
return max([pz.zmax for pz in pzs])

def kernel(self, cosmo, z, ell):
def kernel(self, cosmo, z, ell, z1):
"""
Compute the radial kernel for all nz bins in this probe.

Expand Down Expand Up @@ -225,23 +296,23 @@ def noise(self):
return sigma_e**2 / ngals


@register_pytree_node_class
class NumberCounts(container):
"""Class representing a galaxy clustering probe, with a bunch of bins

Parameters:
-----------
redshift_bins: nzredshift distributions

Configuration:
--------------
has_rsd....
mag_bias....
"""

def __init__(self, redshift_bins, bias, has_rsd=False, **kwargs):
def __init__(self, redshift_bins, bias, has_rsd=False, mag_bias=False, **kwargs):
super(NumberCounts, self).__init__(
redshift_bins, bias, has_rsd=has_rsd, **kwargs
redshift_bins, bias, has_rsd=has_rsd, mag_bias=mag_bias, **kwargs
)
self.mag_bias = mag_bias
self.has_rsd = has_rsd

@property
def zmax(self):
Expand All @@ -259,7 +330,7 @@ def n_tracers(self):
pzs = self.params[0]
return len(pzs)

def kernel(self, cosmo, z, ell):
def kernel(self, cosmo, z, ell, z1):
"""Compute the radial kernel for all nz bins in this probe.

Returns:
Expand All @@ -271,6 +342,13 @@ def kernel(self, cosmo, z, ell):
pzs, bias = self.params
# Retrieve density kernel
kernel = density_kernel(cosmo, pzs, bias, z, ell)

if self.mag_bias:
kernel += mag_kernel(cosmo, pzs, z, ell, self.mag_bias)

if self.has_rsd:
kernel += rsd_kernel(cosmo, pzs, z, ell, z1)

return kernel

def noise(self):
Expand Down