Source code for eztaox.kernels.direct
"""Kernels evaluated using a direct apporach.
The likelihood computation follows O(N^3) scaling.
"""
from __future__ import annotations
import jax.numpy as jnp
import tinygp
from tinygp.helpers import JAXArray
[docs]
class MultibandLowRank(tinygp.kernels.Kernel):
"""A multiband kernel implementating a low-rank Kronecker covariance structure.
The specific form of the cross-band Kronecker covariance matrix is
given by Equation 13 of `Gordon et al. (2020)
<https://arxiv.org/pdf/2007.05799>`_. The implementation is inspired
by
`this tinygp tutorial <https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels>`_.
"""
params: dict[str, JAXArray]
kernel: tinygp.kernels.Kernel
[docs]
def coord_to_sortable(self, X) -> JAXArray:
"""Extract the time-sortable component of the coordinates."""
return X[0]
[docs]
def evaluate(self, X1, X2) -> JAXArray:
"""Evaluate the kernel at a pair of input coordinates."""
amplitudes = self.params["amplitudes"]
return (
amplitudes[X1[1]] * amplitudes[X2[1]] * self.kernel.evaluate(X1[0], X2[0])
)
[docs]
class MultibandFullRank(tinygp.kernels.Kernel):
"""A multiband kernel implementating the full-rank Kronecker covariance structure.
The specific form of the cross-band Kronecker covariance matrix is given by
Equation 18-20 of `Gordon et al. (2020) <https://arxiv.org/pdf/2007.05799>`_.
The implementation is inspired by `this tinygp tutorial <https://tinygp.readthedocs.io/en/stable/tutorials/quasisep-custom.html#multivariate-quasiseparable-kernels>`_.
.. note::
This kernel is still in development, please use with caution.
"""
core_kernel: tinygp.kernels.Kernel
band_kernel: jnp.ndarray
def __init__(self, kernel, diagonal, off_diagonal) -> None:
ndim = diagonal.size
if off_diagonal.size != ((ndim - 1) * ndim) // 2:
raise ValueError(
"Dimension mismatch: expected "
f"(ndim-1)*ndim/2 = {((ndim - 1) * ndim) // 2} elements in "
f"'off_diagonal'; got {off_diagonal.size}"
)
# Construct the band kernel
factor = jnp.zeros((ndim, ndim))
factor = factor.at[jnp.diag_indices(ndim)].add(diagonal)
factor = factor.at[jnp.tril_indices(ndim, -1)].add(off_diagonal)
self.band_kernel = factor @ factor.T
self.core_kernel = kernel
[docs]
def coord_to_sortable(self, X) -> JAXArray:
"""Extract the time-sortable component of the coordinates."""
return X[0]
[docs]
def evaluate(self, X1, X2) -> JAXArray:
"""Evaluate the kernel at a pair of input coordinates."""
t1, b1 = X1
t2, b2 = X2
return self.band_kernel[b1, b2] * self.core_kernel.evaluate(X1, X2)