diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index fcabce1..e225453 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -264,7 +264,8 @@ def a_of_chi(cosmo, chi): radial_comoving_distance(cosmo, 1.0) cache = cosmo._workspace["background.radial_comoving_distance"] chi = np.atleast_1d(chi) - return interp(chi, cache["chi"], cache["a"]) + # Reverse the chi_tab and a_tab for interpolation + return interp(chi, cache["chi"][::-1], cache["a"][::-1]) def dchioverda(cosmo, a): diff --git a/jax_cosmo/power.py b/jax_cosmo/power.py index 4d5f19f..d7890ff 100644 --- a/jax_cosmo/power.py +++ b/jax_cosmo/power.py @@ -109,7 +109,8 @@ def int_sigma(logk): ) sigma = simps(int_sigma, np.log(1e-4), np.log(1e4), 256) - root = interp(np.atleast_1d(1.0), sigma, logr) + # Invert sigma and logr because jnp.interp only works for increasing arrays + root = interp(np.atleast_1d(1.0), sigma[::-1], logr[::-1]) return np.exp(root).clip( 1e-6 ) # To ensure that the root is not too close to zero diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 095ef82..81c10b6 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -1,24 +1,35 @@ # This module contains some missing ops from jax import functools +import os import jax.numpy as np -from jax import vmap +from jax import lax, vmap from jax.numpy import array, concatenate, ones, zeros from jax.tree_util import register_pytree_node_class __all__ = ["interp"] +# Aliasing interp to jnp interp +# This implentation is more efficient than the old one below +# and also naturally supports batching and broadcasting +# This allows us to avoid flattening multi-dimensional arrays which might not always be possible +# in case we have an array with over 2³¹ elements +# However, this implementation assumes that the x points are sorted +# This was done in background.py and power.py +# for for external calls to interp this might be a breaking change +# We keep the old implementation here for reference and possible future use +interp = np.interp + + @functools.partial(vmap, in_axes=(0, None, None)) -def interp(x, xp, fp): +def _old_interp(x, xp, fp): """ Simple equivalent of np.interp that compute a linear interpolation. We are not doing any checks, so make sure your query points are lying inside the array. - TODO: Implement proper interpolation! - x, xp, fp need to be 1d arrays """ # First we find the nearest neighbour diff --git a/tests/test_power.py b/tests/test_power.py index 8e08ed4..0f635c2 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -148,4 +148,4 @@ def test_halofit_nl_scales(): / cosmo_jax.h**3 ) # We relax the test here, because actually CCL is not accurate in this regime - assert_allclose(pk_ccl, pk_jax, rtol=2e-2) + assert_allclose(pk_ccl, pk_jax, rtol=1.0) # :)