Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 1 addition & 28 deletions jax_cosmo/scipy/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,7 @@

__all__ = ["interp"]


@functools.partial(vmap, in_axes=(0, None, None))
def interp(x, xp, fp):
"""
Simple equivalent of np.interp that compute a linear interpolation.

We are not doing any checks, so make sure your query points are lying
inside the array.

TODO: Implement proper interpolation!

x, xp, fp need to be 1d arrays
"""
# First we find the nearest neighbour
ind = np.argmin((x - xp) ** 2)

# Perform linear interpolation
ind = np.clip(ind, 1, len(xp) - 2)

xi = xp[ind]
# Figure out if we are on the right or the left of nearest
s = np.sign(np.clip(x, xp[1], xp[-2]) - xi).astype(np.int64)
a = (fp[ind + np.copysign(1, s).astype(np.int64)] - fp[ind]) / (
xp[ind + np.copysign(1, s).astype(np.int64)] - xp[ind]
)
b = fp[ind] - a * xp[ind]
return a * x + b

interp = np.intern
Comment thread
EiffL marked this conversation as resolved.
Outdated

@register_pytree_node_class
class InterpolatedUnivariateSpline(object):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_spline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# This module tests the InterpolatedUnivariateSpline implementation against
# SciPy
from jax.config import config
from jax import config

config.update("jax_enable_x64", True)
import jax.numpy as np
Expand Down