Skip to content

Commit f84c8e3

Browse files
dfm, claude, and pre-commit-ci[bot]
authored
Fix Block matrix handling in Sum kernels and quasiseparable operations (#266)
* Fix Block transition matrices breaking Sum kernel operations (#265) Block diagonal matrices used in Sum kernel transition matrices, stationary covariances, and design matrices were incompatible with several operations: adding QSMs (banded noise), product kernels (_prod_helper indexing), and elementwise multiplication (self_mul fancy indexing). Convert Block to dense in these contexts since the state-space matrices are small. Also add a use_block=False option to Sum for users who want to bypass Block entirely. https://claude.ai/code/session_01Y2ACGEqvh9fTrCzR5WEPuJ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address PR review: extract ensure_dense helper, move imports to top, trim tests - Extract ensure_dense() into block.py as a shared helper - Move all imports to module top level (no lazy imports) - Remove test docstrings and comments - Consolidate tests to minimal set covering the three failure modes + use_block https://claude.ai/code/session_01Y2ACGEqvh9fTrCzR5WEPuJ * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 71d87d5 commit f84c8e3

5 files changed

Lines changed: 95 additions & 21 deletions

File tree

src/tinygp/kernels/quasisep.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,11 @@
3232
import jax
3333
import jax.numpy as jnp
3434
import numpy as np
35+
from jax.scipy.linalg import block_diag as jsp_block_diag
3536

3637
from tinygp.helpers import JAXArray
3738
from tinygp.kernels.base import Kernel
38-
from tinygp.solvers.quasisep.block import Block
39+
from tinygp.solvers.quasisep.block import Block, ensure_dense
3940
from tinygp.solvers.quasisep.core import DiagQSM, StrictLowerTriQSM, SymmQSM
4041
from tinygp.solvers.quasisep.general import GeneralQSM
4142

@@ -220,20 +221,39 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
220221

221222

222223
class Sum(Quasisep):
223-
"""A helper to represent the sum of two quasiseparable kernels"""
224+
"""A helper to represent the sum of two quasiseparable kernels
225+
226+
Args:
227+
kernel1: The first kernel.
228+
kernel2: The second kernel.
229+
use_block: If ``True`` (default), use :class:`Block` diagonal matrices
230+
for the transition matrices, design matrices, and stationary
231+
covariance. If ``False``, use dense ``block_diag`` representations
232+
instead, which avoids compatibility issues with some operations
233+
(e.g. banded noise, product kernels) at a small performance cost
234+
for the state-space matrices.
235+
"""
224236

225237
kernel1: Quasisep
226238
kernel2: Quasisep
239+
use_block: bool = eqx.field(static=True, default=True)
227240

228241
def coord_to_sortable(self, X: JAXArray) -> JAXArray:
229242
"""We assume that both kernels use the same coordinates"""
230243
return self.kernel1.coord_to_sortable(X)
231244

245+
def _block_or_dense(self, m1: JAXArray, m2: JAXArray) -> JAXArray:
246+
if self.use_block:
247+
return Block(m1, m2)
248+
return jsp_block_diag(m1, m2)
249+
232250
def design_matrix(self) -> JAXArray:
233-
return Block(self.kernel1.design_matrix(), self.kernel2.design_matrix())
251+
return self._block_or_dense(
252+
self.kernel1.design_matrix(), self.kernel2.design_matrix()
253+
)
234254

235255
def stationary_covariance(self) -> JAXArray:
236-
return Block(
256+
return self._block_or_dense(
237257
self.kernel1.stationary_covariance(),
238258
self.kernel2.stationary_covariance(),
239259
)
@@ -247,7 +267,7 @@ def observation_model(self, X: JAXArray) -> JAXArray:
247267
)
248268

249269
def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
250-
return Block(
270+
return self._block_or_dense(
251271
self.kernel1.transition_matrix(X1, X2),
252272
self.kernel2.transition_matrix(X1, X2),
253273
)
@@ -632,6 +652,8 @@ def transition_matrix(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
632652

633653

634654
def _prod_helper(a1: JAXArray, a2: JAXArray) -> JAXArray:
655+
a1 = ensure_dense(a1)
656+
a2 = ensure_dense(a2)
635657
i, j = np.meshgrid(np.arange(a1.shape[0]), np.arange(a2.shape[0]))
636658
i = i.flatten()
637659
j = j.flatten()

src/tinygp/solvers/quasisep/block.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
from tinygp.helpers import JAXArray
1010

1111

12+
def ensure_dense(x: Any) -> Any:
13+
"""Convert a Block to a dense array, passing through non-Block inputs."""
14+
if isinstance(x, Block):
15+
return x.to_dense()
16+
return x
17+
18+
1219
class Block(eqx.Module):
1320
blocks: tuple[Any, ...]
1421
__array_priority__ = 1999

src/tinygp/solvers/quasisep/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from jax.scipy.linalg import block_diag
3030

3131
from tinygp.helpers import JAXArray
32+
from tinygp.solvers.quasisep.block import ensure_dense
3233

3334

3435
def handle_matvec_shapes(
@@ -213,20 +214,24 @@ def impl(
213214
return StrictLowerTriQSM(
214215
p=jnp.concatenate((p1, p2)),
215216
q=jnp.concatenate((q1, q2)),
216-
a=block_diag(a1, a2),
217+
a=block_diag(ensure_dense(a1), ensure_dense(a2)),
217218
)
218219

219220
return impl(self, other)
220221

221222
def self_mul(self, other: StrictLowerTriQSM) -> StrictLowerTriQSM:
222223
"""The elementwise product of two :class:`StrictLowerTriQSM` matrices"""
224+
# vmap is needed because a batched Block has 3D block arrays that
225+
# block_diag (used by to_dense) cannot handle without unbatching.
226+
self_a = jax.vmap(ensure_dense)(self.a)
227+
other_a = jax.vmap(ensure_dense)(other.a)
223228
i, j = np.meshgrid(np.arange(self.p.shape[1]), np.arange(other.p.shape[1]))
224229
i = i.flatten()
225230
j = j.flatten()
226231
return StrictLowerTriQSM(
227232
p=self.p[:, i] * other.p[:, j],
228233
q=self.q[:, i] * other.q[:, j],
229-
a=self.a[:, i[:, None], i[None, :]] * other.a[:, j[:, None], j[None, :]],
234+
a=self_a[:, i[:, None], i[None, :]] * other_a[:, j[:, None], j[None, :]],
230235
)
231236

232237
def __neg__(self) -> StrictLowerTriQSM:

src/tinygp/solvers/quasisep/ops.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import jax.numpy as jnp
99

1010
from tinygp.helpers import JAXArray
11+
from tinygp.solvers.quasisep.block import ensure_dense
1112
from tinygp.solvers.quasisep.core import (
1213
QSM,
1314
DiagQSM,
@@ -145,15 +146,15 @@ def impl(
145146
u += [upper_b.p] if upper_b is not None else []
146147

147148
if lower_a is not None and lower_b is not None:
149+
la_a = ensure_dense(lower_a.a)
150+
lb_a = ensure_dense(lower_b.a)
148151
ell = jnp.concatenate(
149152
(
150-
jnp.concatenate(
151-
(lower_a.a, jnp.outer(lower_a.q, lower_b.p)), axis=-1
152-
),
153+
jnp.concatenate((la_a, jnp.outer(lower_a.q, lower_b.p)), axis=-1),
153154
jnp.concatenate(
154155
(
155-
jnp.zeros((lower_b.a.shape[0], lower_a.a.shape[0])),
156-
lower_b.a,
156+
jnp.zeros((lb_a.shape[0], la_a.shape[0])),
157+
lb_a,
157158
),
158159
axis=-1,
159160
),
@@ -162,33 +163,33 @@ def impl(
162163
)
163164
else:
164165
ell = (
165-
lower_a.a
166+
ensure_dense(lower_a.a)
166167
if lower_a is not None
167-
else lower_b.a if lower_b is not None else None
168+
else ensure_dense(lower_b.a) if lower_b is not None else None
168169
)
169170

170171
if upper_a is not None and upper_b is not None:
172+
ua_a = ensure_dense(upper_a.a)
173+
ub_a = ensure_dense(upper_b.a)
171174
delta = jnp.concatenate(
172175
(
173176
jnp.concatenate(
174177
(
175-
upper_a.a,
176-
jnp.zeros((upper_a.a.shape[0], upper_b.a.shape[0])),
178+
ua_a,
179+
jnp.zeros((ua_a.shape[0], ub_a.shape[0])),
177180
),
178181
axis=-1,
179182
),
180-
jnp.concatenate(
181-
(jnp.outer(upper_b.q, upper_a.p), upper_b.a), axis=-1
182-
),
183+
jnp.concatenate((jnp.outer(upper_b.q, upper_a.p), ub_a), axis=-1),
183184
),
184185
axis=0,
185186
)
186187

187188
else:
188189
delta = (
189-
upper_a.a
190+
ensure_dense(upper_a.a)
190191
if upper_a is not None
191-
else upper_b.a if upper_b is not None else None
192+
else ensure_dense(upper_b.a) if upper_b is not None else None
192193
)
193194

194195
return (

tests/test_kernels/test_quasisep.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from tinygp import GaussianProcess
88
from tinygp.kernels import quasisep
9+
from tinygp.noise import Banded
910
from tinygp.test_utils import assert_allclose
1011

1112

@@ -157,3 +158,41 @@ def test_carma_quads():
157158
assert_allclose(carma31.arroots, carma31_quads.arroots)
158159
assert_allclose(carma31.acf, carma31_quads.acf)
159160
assert_allclose(carma31.obsmodel, carma31_quads.obsmodel)
161+
162+
163+
def test_sum_kernel_with_banded_noise(data):
164+
x, y, _ = data
165+
N = len(x)
166+
k = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
167+
banded = Banded(diag=0.1 * jnp.ones(N), off_diags=0.01 * jnp.ones((N, 1)))
168+
gp = GaussianProcess(k, x, noise=banded)
169+
assert jnp.isfinite(gp.log_probability(y))
170+
lp, cond_gp = gp.condition(y)
171+
assert jnp.isfinite(lp)
172+
173+
174+
def test_product_of_sum_kernel(data):
175+
x, y, _ = data
176+
k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * quasisep.Exp(1.0)
177+
gp = GaussianProcess(k, x, diag=jnp.ones(len(x)))
178+
assert jnp.isfinite(gp.log_probability(y))
179+
assert_allclose(k.to_symm_qsm(x).to_dense(), k(x, x))
180+
181+
182+
def test_sum_times_sum_kernel(data):
183+
x, y, _ = data
184+
k = (quasisep.Cosine(1.0) + quasisep.Cosine(2.0)) * (
185+
quasisep.Exp(0.5) + quasisep.Matern32(1.0)
186+
)
187+
gp = GaussianProcess(k, x, diag=jnp.ones(len(x)))
188+
assert jnp.isfinite(gp.log_probability(y))
189+
190+
191+
def test_sum_kernel_use_block_false(data):
192+
x, y, _ = data
193+
N = len(x)
194+
k_block = quasisep.Cosine(1.0) + quasisep.Cosine(2.0)
195+
k_dense = quasisep.Sum(quasisep.Cosine(1.0), quasisep.Cosine(2.0), use_block=False)
196+
gp_block = GaussianProcess(k_block, x, diag=0.1 * jnp.ones(N))
197+
gp_dense = GaussianProcess(k_dense, x, diag=0.1 * jnp.ones(N))
198+
assert_allclose(gp_block.log_probability(y), gp_dense.log_probability(y))

0 commit comments

Comments (0)