Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/tinygp/kernels/quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,20 +220,41 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:


class Sum(Quasisep):
"""A helper to represent the sum of two quasiseparable kernels"""
"""A helper to represent the sum of two quasiseparable kernels

Args:
kernel1: The first kernel.
kernel2: The second kernel.
use_block: If ``True`` (default), use :class:`Block` diagonal matrices
for the transition matrices, design matrices, and stationary
covariance. If ``False``, use dense ``block_diag`` representations
instead, which avoids compatibility issues with some operations
(e.g. banded noise, product kernels) at a small performance cost
for the state-space matrices.
"""

kernel1: Quasisep
kernel2: Quasisep
use_block: bool = eqx.field(static=True, default=True)

def coord_to_sortable(self, X: JAXArray) -> JAXArray:
"""We assume that both kernels use the same coordinates"""
return self.kernel1.coord_to_sortable(X)

def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray:
if self.use_block:
return Block(m1, m2)
from jax.scipy.linalg import block_diag as jsp_block_diag
Comment thread
dfm marked this conversation as resolved.
Outdated

return jsp_block_diag(m1, m2)

def design_matrix(self) -> JAXArray:
return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix())
return self._block_or_dense(
self.kernel1.design_matrix(), self.kernel2.design_matrix()
)

def stationary_covariance(self) -> JAXArray:
return Block(
return self._block_or_dense(
self.kernel1.stationary_covariance(),
self.kernel2.stationary_covariance(),
)
Expand All @@ -247,7 +268,7 @@ def observation_model(self, X: JAXArray) -> JAXArray:
)

def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
return Block(
return self._block_or_dense(
self.kernel1.transition_matrix(X1, X2),
self.kernel2.transition_matrix(X1, X2),
)
Expand Down Expand Up @@ -632,6 +653,10 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:


def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
if isinstance(a1, Block):
a1 = a1.to_dense()
if isinstance(a2, Block):
a2 = a2.to_dense()
i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0]))
i = i.flatten()
j = j.flatten()
Expand Down
15 changes: 14 additions & 1 deletion src/tinygp/solvers/quasisep/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,18 @@ def scale(self, other: JAXArray) -> StrictLowerTriQSM:

def self_add(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
"""The sum of two :class:`StrictLowerTriQSM` matrices"""
from tinygp.solvers.quasisep.block import Block
Comment thread
dfm marked this conversation as resolved.
Outdated

@jax.vmap
def impl(
self: StrictLowerTriQSM, other: StrictLowerTriQSM
) -> StrictLowerTriQSM:
p1, q1, a1 = self
p2, q2, a2 = other
if isinstance(a1, Block):
Comment thread
dfm marked this conversation as resolved.
Outdated
a1 = a1.to_dense()
if isinstance(a2, Block):
a2 = a2.to_dense()
return StrictLowerTriQSM(
p=jnp.concatenate((p1, p2)),
q=jnp.concatenate((q1, q2)),
Expand All @@ -220,13 +225,21 @@ def impl(

def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
"""The elementwise product of two :class:`StrictLowerTriQSM` matrices"""
from tinygp.solvers.quasisep.block import Block

self_a = self.a
other_a = other.a
if isinstance(self_a, Block):
self_a = jax.vmap(lambda b: b.to_dense())(self_a)
Comment thread
dfm marked this conversation as resolved.
Outdated
if isinstance(other_a, Block):
other_a = jax.vmap(lambda b: b.to_dense())(other_a)
i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1]))
i = i.flatten()
j = j.flatten()
return StrictLowerTriQSM(
p=self.p[:, i] * other.p[:, j],
q=self.q[:, i] * other.q[:, j],
a=self.a[:, i[:, None], i[None, :]] * other.a[:, j[:, None], j[None, :]],
a=self_a[:, i[:, None], i[None, :]] * other_a[:, j[:, None], j[None, :]],
)

def __neg__(self) -> StrictLowerTriQSM:
Expand Down
34 changes: 24 additions & 10 deletions src/tinygp/solvers/quasisep/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
import jax.numpy as jnp

from tinygp.helpers import JAXArray
from tinygp.solvers.quasisep.block import Block


def _ensure_dense(x: JAXArray) -> JAXArray:
"""Convert Block to dense array if needed."""
if isinstance(x, Block):
return x.to_dense()
return x


from tinygp.solvers.quasisep.core import (
QSM,
DiagQSM,
Expand Down Expand Up @@ -145,15 +155,17 @@ def impl(
u += [upper_b.p] if upper_b is not None else []

if lower_a is not None and lower_b is not None:
la_a = _ensure_dense(lower_a.a)
lb_a = _ensure_dense(lower_b.a)
ell = jnp.concatenate(
(
jnp.concatenate(
(lower_a.a, jnp.outer(lower_a.q, lower_b.p)), axis=-1
(la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1
),
jnp.concatenate(
(
jnp.zeros((lower_b.a.shape[0], lower_a.a.shape[0])),
lower_b.a,
jnp.zeros((lb_a.shape[0], la_a.shape[0])),
lb_a,
),
axis=-1,
),
Expand All @@ -162,33 +174,35 @@ def impl(
)
else:
ell = (
lower_a.a
_ensure_dense(lower_a.a)
if lower_a is not None
else lower_b.a if lower_b is not None else None
else _ensure_dense(lower_b.a) if lower_b is not None else None
)

if upper_a is not None and upper_b is not None:
ua_a = _ensure_dense(upper_a.a)
ub_a = _ensure_dense(upper_b.a)
delta = jnp.concatenate(
(
jnp.concatenate(
(
upper_a.a,
jnp.zeros((upper_a.a.shape[0], upper_b.a.shape[0])),
ua_a,
jnp.zeros((ua_a.shape[0], ub_a.shape[0])),
),
axis=-1,
),
jnp.concatenate(
(jnp.outer(upper_b.q, upper_a.p), upper_b.a), axis=-1
(jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1
),
),
axis=0,
)

else:
delta = (
upper_a.a
_ensure_dense(upper_a.a)
if upper_a is not None
else upper_b.a if upper_b is not None else None
else _ensure_dense(upper_b.a) if upper_b is not None else None
)

return (
Expand Down
127 changes: 127 additions & 0 deletions tests/test_kernels/test_quasisep.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,130 @@ def test_carma_quads():
assert_allclose(carma31.arroots, carma31_quads.arroots)
assert_allclose(carma31.acf, carma31_quads.acf)
assert_allclose(carma31.obsmodel, carma31_quads.obsmodel)


# Regression tests for https://github.com/dfm/tinygp/issues/265
# Block transition matrices in Sum kernels broke several operations.


Comment thread
dfm marked this conversation as resolved.
Outdated
def test_sum_kernel_with_banded_noise(data):
"""Sum kernel + banded noise: self_add must handle Block transition matrices"""
Comment thread
dfm marked this conversation as resolved.
Outdated
from tinygp.noise import Banded
Comment thread
dfm marked this conversation as resolved.
Outdated

x, y, _ = data
N = len(x)
k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1)))
gp = GaussianProcess(k, x, noise=banded)
assert jnp.isfinite(gp.log_probability(y))


def test_sum_kernel_with_banded_noise_condition(data):
"""Sum kernel + banded noise: conditioning must handle Block in qsm_mul"""
from tinygp.noise import Banded

x, y, _ = data
N = len(x)
k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1)))
gp = GaussianProcess(k, x, noise=banded)
lp, cond_gp = gp.condition(y)
assert jnp.isfinite(lp)


def test_product_of_sum_kernel(data):
"""Product kernel with Sum factor: _prod_helper must handle Block inputs"""
x, y, _ = data
N = len(x)
k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0)
gp = GaussianProcess(k, x, diag=jnp.ones(N))
assert jnp.isfinite(gp.log_probability(y))


def test_product_of_sum_kernel_consistency(data):
"""Product of sum kernel QSM must match direct kernel evaluation"""
x, _, _ = data
k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0)
assert_allclose(k.to_symm_qsm(x).to_dense(), k(x, x))


def test_sum_times_sum_kernel(data):
"""Product of two Sum kernels: self_mul must handle Block transition matrices"""
x, y, _ = data
N = len(x)
k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * (
quasisep.Exp(0.5) + quasisep.Matern32(1.0)
)
gp = GaussianProcess(k, x, diag=jnp.ones(N))
assert jnp.isfinite(gp.log_probability(y))


def test_sum_kernel_use_block_false(data):
"""Sum kernel with use_block=False bypasses Block entirely"""
x, y, _ = data
N = len(x)
k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False)
gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N))
assert jnp.isfinite(gp.log_probability(y))


def test_sum_kernel_use_block_consistency(data):
"""Block and non-block Sum kernels must produce the same results"""
x, y, _ = data
N = len(x)
k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
k_dense = quasisep.Sum(
quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False
)
gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N))
gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N))
assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y))


def test_sum_kernel_use_block_false_with_banded_noise(data):
"""Sum kernel use_block=False with banded noise"""
from tinygp.noise import Banded

x, y, _ = data
N = len(x)
k = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False)
banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1)))
gp = GaussianProcess(k, x, noise=banded)
assert jnp.isfinite(gp.log_probability(y))


def test_sum_kernel_use_block_false_product(data):
"""Sum kernel use_block=False in a product"""
x, y, _ = data
N = len(x)
k = quasisep.Sum(
quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False
) * quasisep.Exp(1.0)
gp = GaussianProcess(k, x, diag=jnp.ones(N))
assert jnp.isfinite(gp.log_probability(y))


def test_jit_sum_kernel_block(data):
"""JIT must work with Sum kernel block computations"""
x, y, _ = data
N = len(x)
k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)

@jax.jit
def compute(x, y):
gp = GaussianProcess(k, x, diag=0.1 * jnp.ones(N))
return gp.log_probability(y)

assert jnp.isfinite(compute(x, y))


def test_grad_product_of_sum_kernel(data):
"""Gradients must work through product of sum kernel"""
x, y, _ = data
N = len(x)

def loss(sigma):
k = (quasisep.Cosine(sigma) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0)
return GaussianProcess(k, x, diag=jnp.ones(N)).log_probability(y)

assert jnp.isfinite(jax.grad(loss)(1.0))
Loading