diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index b842076..558ba2d 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -11,6 +11,7 @@ import jax_cosmo.power as power import jax_cosmo.transfer as tklib from jax_cosmo.scipy.integrate import simps +from jax_cosmo.scipy.interpolate import interp from jax_cosmo.utils import a2z from jax_cosmo.utils import z2a @@ -78,8 +79,14 @@ def integrand(a): # pk should have shape [na] pk = power.nonlinear_matter_power(cosmo, k, a, transfer_fn, nonlinear_fn) + # RSD inversion + + a_1 = np.clip(bkgrd.a_of_chi(cosmo, (ell + 1.5) / k), 0.00001) + # Compute the kernels for all probes - kernels = np.vstack([p.kernel(cosmo, a2z(a), ell) for p in probes]) + kernels = np.vstack( + [p.kernel(cosmo, a2z(a), ell, a2z(a_1)) for p in probes] + ) # Define an ordering for the blocks of the signal vector cl_index = np.array(_get_cl_ordering(probes)) diff --git a/jax_cosmo/bias.py b/jax_cosmo/bias.py index 6887ce8..015498f 100644 --- a/jax_cosmo/bias.py +++ b/jax_cosmo/bias.py @@ -23,6 +23,21 @@ def __call__(self, cosmo, z): return b * np.ones_like(z) +@register_pytree_node_class +class test_mag_bias(container): + """ + Class representing a more complex bias for magnitude biasing term, just for testing? + + Parameters: + ----------- + b: redshift independent bias value + """ + + def __call__(self, cosmo, z): + b = self.params[0] + return 2.0 / 5.0 + b * np.sqrt(1.0 + z) + + @register_pytree_node_class class inverse_growth_linear_bias(container): """ diff --git a/jax_cosmo/probes.py b/jax_cosmo/probes.py index fab6880..9d8ec93 100644 --- a/jax_cosmo/probes.py +++ b/jax_cosmo/probes.py @@ -76,6 +76,39 @@ def integrand_single(z_prime): return constant_factor * ell_factor * radial_kernel +@jit +def mag_kernel(cosmo, pzs, z, ell, s): + """ + Returns a magnification kernel + + Needs magnification bias function + s = "logarithmic derivative of the number of sources with magnitude limit", a function valid for all z in z_prime + + """ + z = np.atleast_1d(z) + zmax = max([pz.zmax for pz in pzs]) + # Retrieve comoving distance corresponding to z + chi = bkgrd.radial_comoving_distance(cosmo, z2a(z)) + + @vmap + def integrand(z_prime): + chi_prime = bkgrd.radial_comoving_distance(cosmo, z2a(z_prime)) + # Stack the dndz of all redshift bins + dndz = np.stack([pz(z_prime) for pz in pzs], axis=0) + + mag_lim = (2.0 - 5.0 * s(cosmo, z_prime)) / 2.0 + + return dndz * np.clip(chi_prime - chi, 0) / np.clip(chi_prime, 1.0) * mag_lim + + # Computes the radial weak lensing kernel + radial_kernel = np.squeeze(simps(integrand, z, zmax, 256) * (1.0 + z) * chi) + # Constant term (maybe one too many 2.0?) + constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c / 2.0 + # Ell dependent factor + ell_factor = ell * (ell + 1) + return constant_factor * ell_factor * radial_kernel + + @jit def density_kernel(cosmo, pzs, bias, z, ell): """ @@ -131,6 +164,44 @@ def nla_kernel(cosmo, pzs, bias, z, ell): return constant_factor * ell_factor * radial_kernel +@jit +def rsd_kernel(cosmo, pzs, z, ell, z1): + """ + Computes the RSD kernel + """ + print(z, z1) + # stack the dndz of all redshift bins + dndz = np.stack([pz(z) for pz in pzs], axis=0) + + # Normalization, + constant_factor = 1.0 + # Ell dependent factor + ell_factor1 = (1 + 8 * ell) / ((2 * ell + 1) ** 2.0) + # stack the dndz of all redshift bins + dndz = np.stack([pz(z) for pz in pzs], axis=0) + radial_kernel1 = ( + dndz + * bkgrd.growth_rate(cosmo, z2a(z)) + / bkgrd.growth_factor(cosmo, z2a(z)) + * bkgrd.H(cosmo, z2a(z)) + ) + + # Ell dependent factor + ell_factor2 = (4) / (2 * ell + 3) * np.sqrt((2 * ell + 1) / (2 * ell + 3)) + # stack the dndz of all redshift bins + dndz = np.stack([pz(z1) for pz in pzs], axis=0) + radial_kernel2 = ( + dndz + * bkgrd.growth_rate(cosmo, z2a(z1)) + / bkgrd.growth_factor(cosmo, z2a(z1)) + * bkgrd.H(cosmo, z2a(z1)) + ) + + return constant_factor * ( + ell_factor1 * radial_kernel1 - ell_factor2 * radial_kernel2 + ) + + @register_pytree_node_class class WeakLensing(container): """ @@ -187,7 +258,7 @@ def zmax(self): pzs = self.params[0] return max([pz.zmax for pz in pzs]) - def kernel(self, cosmo, z, ell): + def kernel(self, cosmo, z, ell, z1): """ Compute the radial kernel for all nz bins in this probe. @@ -225,23 +296,23 @@ def noise(self): return sigma_e**2 / ngals -@register_pytree_node_class class NumberCounts(container): """Class representing a galaxy clustering probe, with a bunch of bins - Parameters: ----------- redshift_bins: nzredshift distributions - Configuration: -------------- has_rsd.... + mag_bias.... """ - def __init__(self, redshift_bins, bias, has_rsd=False, **kwargs): + def __init__(self, redshift_bins, bias, has_rsd=False, mag_bias=False, **kwargs): super(NumberCounts, self).__init__( - redshift_bins, bias, has_rsd=has_rsd, **kwargs + redshift_bins, bias, has_rsd=has_rsd, mag_bias=mag_bias, **kwargs ) + self.mag_bias = mag_bias + self.has_rsd = has_rsd @property def zmax(self): @@ -259,7 +330,7 @@ def n_tracers(self): pzs = self.params[0] return len(pzs) - def kernel(self, cosmo, z, ell): + def kernel(self, cosmo, z, ell, z1): """Compute the radial kernel for all nz bins in this probe. Returns: @@ -271,6 +342,13 @@ def kernel(self, cosmo, z, ell): pzs, bias = self.params # Retrieve density kernel kernel = density_kernel(cosmo, pzs, bias, z, ell) + + if self.mag_bias: + kernel += mag_kernel(cosmo, pzs, z, ell, self.mag_bias) + + if self.has_rsd: + kernel += rsd_kernel(cosmo, pzs, z, ell, z1) + return kernel def noise(self):