Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions src/grid/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,8 @@ def func(x, y):
# x has shape (1,) and y has shape (K+1,1), output has shape (K+1,1)
x = np.array([x])
if transform:
# Transform the points back to the original domain.
orig_dom = transform.inverse(x)
dy_dx = _transform_and_rearrange_to_explicit_ode(orig_dom, y, coeffs, transform, fx)
# x here is in the finite domain; pass it directly.
dy_dx = _transform_and_rearrange_to_explicit_ode(x, y, coeffs, transform, fx)
else:
coeffs_mt = _evaluate_coeffs_on_points(x, coeffs)
dy_dx = _rearrange_to_explicit_ode(y, coeffs_mt, fx(x))
Expand All @@ -120,21 +119,21 @@ def func(x, y):
return np.vstack((*y[1:, :], dy_dx))

if transform:
# first check if the bounds are in the domain
if min(x_span) < transform.domain[0] or max(x_span) > transform.domain[1]:
# first check if the bounds are in the codomain (infinite domain r ∈ [0,∞))
if min(x_span) < transform.codomain[0] or max(x_span) > transform.codomain[1]:
raise ValueError(
f"The x_span {min(x_span), max(x_span)} is not within the transform "
f"domain {transform.domain}."
)
Comment thread
Ao-chuba marked this conversation as resolved.
# Convert the initial value problem to the new derivative space, only transform up to K-1
# e.g. the first derivative dV/dx = dV/dr * dr/dx = dV/dr / (dx/dr)
# e.g. the first derivative dV/dx = dV/dr * dr/dx = dV/dr * deriv_inverse
Comment thread
Ao-chuba marked this conversation as resolved.
deriv = _derivative_transformation_matrix(
[transform.deriv, transform.deriv2, transform.deriv3],
[transform.deriv_inverse, transform.deriv2_inverse, transform.deriv3_inverse],
x_span[0],
order - 1, # Only need derivatives up to K-1.
)
# If transform is used, then transform (x_0, x_1) that it integrates up to.
x_span = transform.transform(np.array(list(x_span)))
# If transform is used, map (r0, r1) to finite domain (x0, x1) via inverse.
x_span = transform.inverse(np.array(list(x_span)))
# Solve for derivatives in original domain by solving A(original derivs) = new derivs
y_derivs = solve(deriv, np.array(y0[1:]))
if np.any(np.isinf(y_derivs)):
Expand Down Expand Up @@ -236,9 +235,8 @@ def solve_ode_bvp(
def func(x, y):
# x has shape (N,) and y has shape (K+1, N), output has shape (K+1, N)
if transform:
# Transform the points back to the original domain.
orig_dom = transform.inverse(x)
dy_dx = _transform_and_rearrange_to_explicit_ode(orig_dom, y, coeffs, transform, fx)
# x here is in the finite domain; pass it directly.
dy_dx = _transform_and_rearrange_to_explicit_ode(x, y, coeffs, transform, fx)
else:
coeffs_mt = _evaluate_coeffs_on_points(x, coeffs)
dy_dx = _rearrange_to_explicit_ode(y, coeffs_mt, fx(x))
Expand All @@ -262,7 +260,8 @@ def bc(ya, yb):

# Solve the ODE
if transform:
pts_tf = transform.transform(x)
# Map the radial points r to finite domain x via the forward transform's inverse.
pts_tf = transform.inverse(x)
res = solve_bvp(func, bc, pts_tf, y=initial_guess_y, tol=tol, max_nodes=max_nodes)
Comment thread
Ao-chuba marked this conversation as resolved.
else:
res = solve_bvp(func, bc, x, y=initial_guess_y, tol=tol, max_nodes=max_nodes)
Expand All @@ -284,19 +283,23 @@ def _transform_solution_to_original_domain(result, tf, no_derivs, order):

# Note this is it's own function becuase it is used twice for solve_ode_ivp and bv.
def interpolate_wrt_original_var(pt):
transf_pts = tf.transform(pt)
# Map radial points r to finite domain x via the forward transform's inverse.
transf_pts = tf.inverse(pt)
# Row is which func/deriv and Col is points.
interpolated = result.sol(transf_pts)
# If derivatives are not wanted then only return y(x).
if no_derivs:
if interpolated.ndim == 1:
return interpolated
return interpolated[0, :]
deriv_funcs = [tf.deriv, tf.deriv2, tf.deriv3]
# Use inverse derivative methods from BaseTransform (added in #307).
# The ODE is solved in the transformed domain x; converting back to r requires
# dx/dr derivatives, i.e. the derivatives of the inverse transformation.
deriv_funcs = [tf.deriv_inverse, tf.deriv2_inverse, tf.deriv3_inverse]
new_interpolate = np.zeros(interpolated.shape)
new_interpolate[0, :] = interpolated[0, :]
for i in range(interpolated.shape[1]):
# Calculate the jacobian dr/dx of the original domain.
# Calculate the jacobian dx/dr of the original domain.
deriv = _derivative_transformation_matrix(deriv_funcs, pt[i], order - 1)
new_interpolate[1:, i] = deriv.dot(interpolated[1:, i])
return new_interpolate
Expand Down Expand Up @@ -407,8 +410,12 @@ def _transform_ode_from_rtransform(coeff_a: list | np.ndarray, tf: BaseTransform
Coefficients :math:`b_j(r)` of the new ODE with respect to transformed variable :math:`r`.

"""
deriv_func = [tf.deriv, tf.deriv2, tf.deriv3]
return _transform_ode_from_derivs(coeff_a, deriv_func, x)
# x here is the finite domain points of tf.
# We want to transform the ODE from the infinite domain r = tf(x) to the finite domain x.
r = tf.transform(x)
# The derivatives we need are the derivatives of the inverse transform (dr -> dx), i.e. tf.deriv_inverse
deriv_func = [tf.deriv_inverse, tf.deriv2_inverse, tf.deriv3_inverse]
return _transform_ode_from_derivs(coeff_a, deriv_func, r)


def _transform_and_rearrange_to_explicit_ode(
Expand Down Expand Up @@ -442,8 +449,9 @@ def _transform_and_rearrange_to_explicit_ode(
to explicit form evaluated on all N points.

"""
orig_dom = tf.transform(x)
coeff_b = _transform_ode_from_rtransform(coeff_a, tf, x)
result = _rearrange_to_explicit_ode(y, coeff_b, fx_func(x))
result = _rearrange_to_explicit_ode(y, coeff_b, fx_func(orig_dom))
return result


Expand Down
35 changes: 25 additions & 10 deletions src/grid/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ def solve_poisson_ivp(
r_interval: tuple = (1000, 1e-5),
ode_params: dict | type(None) = None,
):
from grid.rtransform import InverseRTransform

if isinstance(transform, InverseRTransform):
transform = transform._tfm


r"""
Comment thread
Ao-chuba marked this conversation as resolved.
Outdated
Return interpolation of the solution to the Poisson equation solved as an initial value problem.

Expand All @@ -194,9 +200,11 @@ def solve_poisson_ivp(
spherical harmonic basis.
func_vals : ndarray(N,)
The function values evaluated on all :math:`N` points on the molecular grid.
transform : BaseTransform, optional
Transformation from infinite domain :math:`r \in [0, \infty)` to another
domain that is a finite.
transform : BaseTransform
Forward transformation from the finite domain :math:`x` to the infinite radial
domain :math:`r \in [0, \infty)`, i.e. the same transform used to construct the
atomic grid. Inverse-related quantities are computed internally via
``BaseTransform.deriv_inverse`` and related methods.
r_interval : tuple, optional
The interval :math:`(b, a)` of :math:`r` for which the ODE solver will start from and end,
where :math:`b>a`\. The value :math:`b` should be large as it determines the asymptotic
Expand Down Expand Up @@ -244,10 +252,10 @@ def _solve_poisson_bvp_atomgrid(
sph_o_l = generate_real_spherical_harmonics(0, np.array([0.1]), np.array([0.1]))
boundary = atomgrid.integrate(func_vals) / sph_o_l[0, 0]

# Check if the domain of transform is in [0, \infty)
domain = transform.domain
if domain[0] < 0.0:
raise ValueError(f"The domain of the transform {domain} should be in [0, infinity).")
# Check if the codomain of transform is in [0, \infty)
codomain = transform.codomain
if codomain[0] < 0.0:
raise ValueError(f"The codomain of the transform {codomain} should be in [0, infinity).")

# Get the radial components from expanding func into real spherical harmonics.
radial_components = atomgrid.radial_component_splines(func_vals)
Expand Down Expand Up @@ -323,6 +331,11 @@ def solve_poisson_bvp(
remove_large_pts: float = 1e6,
ode_params: dict | type(None) = None,
):
from grid.rtransform import InverseRTransform

if isinstance(transform, InverseRTransform):
transform = transform._tfm

r"""
Return interpolation of the solution to the Poisson equation solved as a boundary value problem.
Comment thread
Ao-chuba marked this conversation as resolved.
Outdated

Expand All @@ -346,9 +359,11 @@ def solve_poisson_bvp(
harmonic basis.
func_vals : ndarray(N,)
The function values evaluated on all :math:`N` points on the molecular grid.
transform : BaseTransform, optional
Transformation from infinite domain :math:`r \in [0, \infty)` to another
domain that is a finite.
transform : BaseTransform
Forward transformation from the finite domain :math:`x` to the infinite radial
domain :math:`r \in [0, \infty)`, i.e. the same transform used to construct the
atomic grid. Inverse-related quantities are computed internally via
``BaseTransform.deriv_inverse`` and related methods.
boundary : float, optional
The boundary value of :math:`g` in the limit of r to infinity.
include_origin : bool, optional
Expand Down
16 changes: 15 additions & 1 deletion src/grid/rtransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,14 @@ def inverse(self, r: np.ndarray):


class InverseRTransform(BaseTransform):
"""Inverse transformation class for any general transformation."""
"""Inverse transformation class for any general transformation.

.. deprecated::
``InverseRTransform`` is deprecated and will be removed in a future release.
The inverse transformation derivatives (``deriv_inverse``, ``deriv2_inverse``,
``deriv3_inverse``) are now available directly on every ``BaseTransform`` subclass.
Pass the forward transform and call those methods instead.
"""

def __init__(self, transform: BaseTransform):
"""Construct InverseRTransform instance.
Expand All @@ -526,6 +533,13 @@ def __init__(self, transform: BaseTransform):
One-dimension transformation instance.

"""
warnings.warn(
"InverseRTransform is deprecated and will be removed in a future release. "
"Use the forward transform directly and call deriv_inverse(), deriv2_inverse(), "
"or deriv3_inverse() on it instead.",
DeprecationWarning,
stacklevel=2,
)
if not isinstance(transform, BaseTransform):
raise TypeError(f"Input need to be a transform instance, got {type(transform)}.")
self._tfm = transform
Expand Down
Loading
Loading