Skip to content

Use jnp.interp#129

Open
ASKabalan wants to merge 15 commits intoDifferentiableUniverseInitiative:masterfrom
ASKabalan:use-jnp.interp
Open

Use jnp.interp#129
ASKabalan wants to merge 15 commits intoDifferentiableUniverseInitiative:masterfrom
ASKabalan:use-jnp.interp

Conversation

@ASKabalan
Copy link
Copy Markdown
Member

Using jax numpy interpolate instead of the custom code that was made before jnp.interp was implemented

Notebook proving that jnp.interp is more accurate and much faster

@jecampagne
Copy link
Copy Markdown
Collaborator

Although if I agree to use jax.numpy.interp as it is now implemented in the JAX lib. I was curious to see where jax-cosmo fails and to cure the problem.

Here is my 1-cent (in the context of jax-cosmo one would use the decorator and switch jnp an np due to the jax-ccsmo naming convention (ie. np would not be the numpy lib)

#@functools.partial(jax.vmap, in_axes=(0, None, None))
def interp_modif(x, xp, fp):
    """
    Simple equivalent of np.interp that compute a linear interpolation.

    We are not doing any checks, so make sure your query points are lying
    inside the array.

    x, xp, fp need to be 1d arrays
    """

    x = jnp.atleast_1d(x)

    # First we find the nearest neighbour
    ind = jnp.argmin((x - xp) ** 2)

    # Perform linear interpolation
    ind = jnp.clip(ind, 0, len(xp) - 2)

    xi = xp[ind]


    # Figure out if we are on the right or the left of nearest
    s = jnp.sign(jnp.clip(x, xp[0], xp[-2]) - xi)
    s =jax.lax.convert_element_type(s,jnp.int32)

    one = jnp.copysign(1, s)
    one = jax.lax.convert_element_type(one,jnp.int32)
    
    a = (fp[ind + one] - fp[ind]) / (
        xp[ind + one] - xp[ind]
    )
    b = fp[ind] - a * xp[ind]
    return jnp.squeeze(a * x + b)

The failure comes essentialy from the two clipping lower bounds. I have also remove the casting to int64.
You can see the result in the Google nb: https://colab.research.google.com/drive/1QhFG-G0J8Tyq9YPUuxdojvJdNEaDGoVi?usp=sharing

Comment thread jax_cosmo/scipy/interpolate.py Outdated
@ASKabalan ASKabalan mentioned this pull request Mar 2, 2026
8 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants