Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jax_cosmo.utils import a2z, z2a


Expand Down Expand Up @@ -47,7 +48,12 @@ def find_index(a, b):


def angular_cl(
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit
cosmo,
ell,
probes,
transfer_fn=tklib.Eisenstein_Hu,
nonlinear_fn=power.halofit,
npoints=128,
):
"""
Computes angular Cls for the provided probes
Expand Down Expand Up @@ -90,10 +96,15 @@ def combine_kernels(inds):

result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi**2, 1.0)

# We transpose the result just to make sure that na is first
return result.T
return result

return simps(integrand, z2a(zmax), 1.0, 512) / const.c**2
atab = np.linspace(z2a(zmax), 1.0, npoints)
eval_integral = vmap(
lambda x: np.squeeze(
InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0)
)
)
return eval_integral(integrand(atab)) / const.c**2

return cl(ell)

Expand Down
29 changes: 29 additions & 0 deletions jax_cosmo/scipy/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,32 @@ def integral(self, a, b):
sign = -1
xs = np.array([a, b])
return sign * np.diff(self.antiderivative(xs))


def splint(func, a, b, k=3, N=128):
"""Function that computes an integration with a spline function
slightly different from the original splint from scipy
"""
x = np.linspace(a, b, N)
return InterpolatedUnivariateSpline(x, func(x), k=k).integral(a, b)


Comment thread
EiffL marked this conversation as resolved.
Outdated
# def splint_fwd(func, a, b, **kwargs):
# result = splint(func, a, b, **kwargs)
# aux = (a, b, kwargs)
# return result, aux

# def splint_bwd(func, aux, grad):
# a, b, kwargs = aux

# grad_a = -grad * func(a)
# grad_b = grad * func(b)

# grad_args = []
# for i in range(len(args)):
# def _vjp_func(_t, *_args):
# return jax.grad(func, i)(_t, *_args)
# grad_args.append(grad * quad(_vjp_func, a, b, args))
# grad_args = tuple(grad_args)

# return grad_a, grad_b, grad_args
Comment thread
EiffL marked this conversation as resolved.
Outdated