"""Light curve models.

A module of light curve models, which are the interface for modeling
uni/multi-band light curves using Gaussian Processes (GPs).
"""
from collections.abc import Callable
from functools import partial
import equinox as eqx
import jax
import jax.flatten_util
import jax.numpy as jnp
import numpyro
import tinygp.kernels as tk
from numpy.typing import NDArray
from tinygp import GaussianProcess
from tinygp.helpers import JAXArray
from eztaox.kernels import direct, quasisep
from eztaox.ts_utils import merge_sort
class MultiVarModel(eqx.Module):
    """An interface for modeling multivariate/multi-band time series using GPs.

    This interface is designed around GP kernels that can be evaluated using the
    scalable method of `DFM+17 <https://arxiv.org/abs/1703.09710>`. When no
    ``multiband_kernel`` is given, a tinygp quasiseparable ``base_kernel`` is
    paired with ``kernels.quasisep.MultibandLowRank``; any other base kernel
    falls back to ``kernels.direct.MultibandLowRank``. The interface allows
    fitting for a parameterized mean function of the time series, additional
    variance to the measurement uncertainty, and time delays between each
    uni-variate/single-band time series.

    Args:
        X (JAXArray|NDArray): Input data containing time and band indices as a
            tuple ``(t, band)``.
        y (JAXArray|NDArray): Observed data values.
        yerr (JAXArray|NDArray): Observational uncertainties.
        base_kernel (Quasisep): A GP kernel from the kernels.quasisep module.
        n_band (int): An integer number of bands in the input light curve.
        multiband_kernel (Quasisep, optional): A multiband kernel specifying the
            cross-band covariance, defaults to
            kernels.quasisep.MultibandLowRank.
        mean_func (Callable, optional): A callable mean function for the GP,
            called as ``mean_func(params, X)``. Defaults to None.
        amp_scale_func (Callable, optional): A callable amplitude scaling
            function, called as ``amp_scale_func(params)``. Defaults to None.
        lag_func (Callable, optional): A callable function for time delays
            between bands, called as ``lag_func(params)``. Defaults to None.
        **kwargs: Additional keyword arguments.
            - `zero_mean` (bool): If True, assumes zero-mean GP. Defaults to
              True.
            - `has_jitter` (bool): If True, assumes the input observational
              errors are underestimated. Defaults to False.
            - `has_lag` (bool): If True, assumes time delays between time
              series in each band. Defaults to False.
    """

    # (time, band) of the observations, sorted by time
    X: tuple[JAXArray, JAXArray]
    # observed values, in the same time-sorted order as X
    y: JAXArray
    # squared observational uncertainties, in the same order as X
    diag: JAXArray
    # unravel function that rebuilds the base kernel from a flat parameter array
    base_kernel_def: Callable
    # NOTE: invoked as a factory, ``multiband_kernel(params=..., kernel=...)``
    multiband_kernel: tk.Kernel | tk.quasisep.Wrapper
    # time stamps split per band (each one sorted, since X is sorted)
    t_in_bands: list[JAXArray]
    # positions in X of every band's observations, concatenated band by band
    concat_inds_in_bands: JAXArray
    n_band: int
    mean_func: Callable | None
    amp_scale_func: Callable | None
    lag_func: Callable | None
    zero_mean: bool
    has_jitter: bool
    has_lag: bool

    def __init__(
        self,
        X: tuple[JAXArray | NDArray, JAXArray | NDArray],
        y: JAXArray | NDArray,
        yerr: JAXArray | NDArray,
        base_kernel: tk.Kernel | quasisep.Quasisep,
        n_band: int,
        *,
        multiband_kernel: tk.Kernel | tk.quasisep.Wrapper | None = None,
        mean_func: Callable | None = None,
        amp_scale_func: Callable | None = None,
        lag_func: Callable | None = None,
        **kwargs,
    ) -> None:
        # format inputs
        t = jnp.asarray(X[0])
        band = jnp.asarray(X[1], dtype=int)
        y = jnp.asarray(y)
        yerr = jnp.asarray(yerr)
        init_inds = jnp.argsort(t)

        # assign attributes, everything stored in time-sorted order
        self.X = (t[init_inds], band[init_inds])
        self.diag = (yerr**2)[init_inds]
        self.y = y[init_inds]
        # keep only the unravel function; kernel parameters come from `params`
        self.base_kernel_def = jax.flatten_util.ravel_pytree(base_kernel)[1]
        self.n_band = n_band

        # assign band indices for sorting the input time axis after lag transform
        sorted_t, sorted_band = self.X
        unique_bands = jnp.unique(sorted_band)
        inds_in_bands = [jnp.where(sorted_band == i)[0] for i in unique_bands]
        self.t_in_bands = [sorted_t[band_inds] for band_inds in inds_in_bands]
        self.concat_inds_in_bands = jnp.concat(inds_in_bands)

        # assign callables/classes
        if multiband_kernel is None:
            if isinstance(base_kernel, tk.quasisep.Quasisep):
                multiband_kernel = quasisep.MultibandLowRank
            else:
                multiband_kernel = direct.MultibandLowRank
        self.multiband_kernel = multiband_kernel
        self.mean_func = mean_func
        self.amp_scale_func = amp_scale_func
        self.lag_func = lag_func

        # assign other attributes
        self.zero_mean = kwargs.get("zero_mean", True)
        self.has_jitter = kwargs.get("has_jitter", False)
        self.has_lag = kwargs.get("has_lag", False)

    def get_mean(
        self, zero_mean: bool, params: dict[str, JAXArray], X: JAXArray
    ) -> JAXArray:
        """Return the mean of the GP.

        Args:
            zero_mean (bool): If True, the mean is the constant 0.0.
            params (dict[str, JAXArray]): Model parameters, forwarded to the
                mean function.
            X (JAXArray): Input coordinates, forwarded to the mean function.

        Returns:
            JAXArray: 0.0 if ``zero_mean`` is True, otherwise the output of the
            user-supplied ``mean_func`` or, if none was given, of the default
            per-band mean function.
        """
        if zero_mean is True:
            return 0.0
        if self.mean_func is not None:
            return self.mean_func(params, X)
        return self._default_mean_func(params, X)

    def get_amp_scale(self, params: dict[str, JAXArray]) -> JAXArray:
        """Return the amplitude of the GP in each individual band.

        The returned values are exponentiated in ``_build_gp``, i.e., they are
        treated as log amplitude scales.

        Args:
            params (dict[str, JAXArray]): Model parameters.

        Returns:
            JAXArray: Output of the user-supplied ``amp_scale_func`` or, if none
            was given, of the default amplitude scaling function.
        """
        if self.amp_scale_func is not None:
            return self.amp_scale_func(params)
        return self._default_amp_scale_func(params)

    def _get_lags(self, has_lag: bool, params: dict[str, JAXArray]) -> JAXArray:
        """Return the time delay of every band (all zeros when lags are off)."""
        if has_lag is False:
            return jnp.zeros(self.n_band)
        if self.lag_func is not None:
            return self.lag_func(params)
        return self._default_lag_func(params)

    def lag_transform(
        self,
        has_lag: bool,
        params: dict[str, JAXArray],
        X: tuple[JAXArray, JAXArray],
    ) -> tuple[tuple[JAXArray, JAXArray], JAXArray]:
        """Shift an arbitrary time axis by the lag in each band.

        Args:
            has_lag (bool): If False, no shift is applied (all lags are zero).
            params (dict[str, JAXArray]): Model parameters, forwarded to the lag
                function.
            X (tuple[JAXArray, JAXArray]): Time stamps and band indices to
                transform.

        Returns:
            tuple[tuple[JAXArray, JAXArray], JAXArray]: The shifted times and
            band indices (in input order), and the indices that sort the
            shifted times.
        """
        lags = self._get_lags(has_lag, params)
        t = jnp.asarray(X[0])
        band = jnp.asarray(X[1], dtype=int)
        new_t = t - lags[band]
        return (new_t, band), jnp.argsort(new_t)

    def _lag_transform_fast(
        self, has_lag: bool, params: dict[str, JAXArray]
    ) -> tuple[tuple[JAXArray, JAXArray], JAXArray]:
        """Shift the time axis by the lag in each band. Fast version used in fitting.

        Operates on the stored observations (``self.X``) only.

        Returns:
            tuple[tuple[JAXArray, JAXArray], JAXArray]: The shifted times and
            band indices (in the stored order), and the indices that sort the
            shifted times.
        """
        lags = self._get_lags(has_lag, params)
        t, band = self.X
        new_t = t - lags[band]

        # use merge sort to get the sorted indices for the new time after lag
        # transform: each band is shifted by a constant, so it stays sorted
        shifted_t_in_bands = jax.tree_util.tree_map(
            lambda time, lag: time - lag, self.t_in_bands, list(lags)
        )
        inds = self.concat_inds_in_bands[merge_sort(*shifted_t_in_bands)]
        return (new_t, band), inds

    def log_prob(self, params: dict[str, JAXArray]) -> JAXArray:
        """Calculate the log probability of the input parameters.

        Args:
            params (dict[str, JAXArray]): Model parameters.

        Returns:
            JAXArray: Log probability of the input parameters.
        """
        gp, inds = self._build_gp(params)
        return gp.log_probability(y=self.y[inds])

    def aic(self, params: dict[str, JAXArray]) -> JAXArray:
        """Calculate the Akaike Information Criterion (AIC) for the model.

        Args:
            params (dict[str, JAXArray]): Maximum likelihood model parameters.

        Returns:
            JAXArray: AIC value.
        """
        # number of free parameters = length of the flattened parameter pytree
        k = len(jax.flatten_util.ravel_pytree(params)[0])
        return 2 * k - 2 * self.log_prob(params)

    def bic(self, params: dict[str, JAXArray]) -> JAXArray:
        """Calculate the Bayesian Information Criterion (BIC) for the model.

        Args:
            params (dict[str, JAXArray]): Maximum likelihood model parameters.

        Returns:
            JAXArray: BIC value.
        """
        n = self.y.size
        # number of free parameters = length of the flattened parameter pytree
        k = len(jax.flatten_util.ravel_pytree(params)[0])
        return jnp.log(n) * k - 2 * self.log_prob(params)

    def sample(self, params: dict[str, JAXArray]) -> None:
        """Integrate with numpyro for MCMC sampling.

        Registers a numpyro sample site named "gp" whose observations are the
        stored data values.

        Args:
            params (dict[str, JAXArray]): Model parameters.
        """
        gp, inds = self._build_gp(params)
        numpyro.sample("gp", gp.numpyro_dist(), obs=self.y[inds])

    def pred(
        self, params: dict[str, JAXArray], X: tuple[JAXArray, JAXArray]
    ) -> tuple[JAXArray, JAXArray]:
        """Make conditional GP prediction.

        Args:
            params (dict[str, JAXArray]): A dictionary containing model
                parameters.
            X (tuple[JAXArray, JAXArray]): The time and band information for
                creating the conditional GP prediction.

        Returns:
            tuple[JAXArray, JAXArray]: A tuple of the mean GP prediction and its
            uncertainty (square root of the predicted variance).
        """
        # transform time axis; the prediction points keep their input order
        new_X, _ = self.lag_transform(self.has_lag, params, X)

        # build gp, cond
        gp, inds = self._build_gp(params)
        _, cond = gp.condition(self.y[inds], new_X)
        return cond.loc, jnp.sqrt(cond.variance)

    def _default_mean_func(self, params: dict[str, JAXArray], X: JAXArray) -> JAXArray:
        """Look up ``params["mean"]`` at the band index stored in ``X[1]``."""
        return jnp.atleast_1d(params["mean"])[X[1]]

    def _default_amp_scale_func(self, params: dict[str, JAXArray]) -> JAXArray:
        """Return ``params["log_amp_scale"]`` with 0.0 prepended as first entry."""
        return jnp.insert(jnp.atleast_1d(params["log_amp_scale"]), 0, 0.0)

    def _default_lag_func(self, params: dict[str, JAXArray]) -> JAXArray:
        """Return ``params["lag"]`` with 0.0 prepended as first entry."""
        return jnp.insert(jnp.atleast_1d(params["lag"]), 0, 0.0)

    @eqx.filter_jit
    def _build_gp(
        self, params: dict[str, JAXArray]
    ) -> tuple[GaussianProcess, JAXArray]:
        """Build the GP for the given parameters.

        Reads ``params["log_kernel_param"]`` and, if ``has_jitter`` is True,
        ``params["log_jitter"]``, in addition to whatever the amplitude, mean and
        lag functions read.

        Args:
            params (dict[str, JAXArray]): Model parameters.

        Returns:
            tuple[GaussianProcess, JAXArray]: The GP defined on the lag-shifted,
            sorted observations, and the indices used to sort them (apply the
            same indices to ``self.y`` before passing it to the GP).
        """
        # log amp + mean
        log_amp_scales = self.get_amp_scale(params)
        means = partial(self.get_mean, self.zero_mean, params)

        # time axis transform: t and band are not sorted,
        # inds gives the sorted indices for the new_t
        X, inds = self._lag_transform_fast(self.has_lag, params)
        t, band = X

        # add jitter to the diagonal
        diags = self.diag
        if self.has_jitter is True:
            jitter_var = jnp.exp(jnp.atleast_1d(params["log_jitter"])) ** 2
            diags = self.diag + jitter_var[band]

        # def kernel; copy so the caller's dictionary is not modified
        new_params = params.copy()
        new_params["amplitudes"] = jnp.exp(log_amp_scales)
        kernel = self.multiband_kernel(
            params=new_params,
            kernel=self.base_kernel_def(jnp.exp(new_params["log_kernel_param"])),
        )

        gp_kwargs = {
            "diag": diags[inds],
            "mean": means,
        }
        # inputs are already sorted through `inds`
        if isinstance(kernel, tk.quasisep.Quasisep):
            gp_kwargs["assume_sorted"] = True

        return (
            GaussianProcess(
                kernel,
                (t[inds], band[inds]),
                **gp_kwargs,
            ),
            inds,
        )
class UniVarModel(MultiVarModel):
    """Subclass MultiVarModel for modeling univariate/single-band time series data.

    All observations are assigned to band 0 and no time delay is applied.

    Args:
        t (JAXArray|NDArray): Time stamps of the input light curve.
        y (JAXArray|NDArray): Observed data values at the corresponding time
            stamps.
        yerr (JAXArray|NDArray): Observational uncertainties.
        kernel (Quasisep): A GP kernel from the eztaox.kernels.quasisep module.
        mean_func (Callable, optional): A callable mean function for the GP,
            defaults to None.
        amp_scale_func (Callable, optional): A callable amplitude scaling
            function, defaults to None.
        **kwargs: Additional keyword arguments.
            - `zero_mean` (bool): If True, assumes zero-mean GP. Defaults to
              True.
            - `has_jitter` (bool): If True, assumes the input observational
              errors are underestimated. Defaults to False.
    """

    def __init__(
        self,
        t: JAXArray | NDArray,
        y: JAXArray | NDArray,
        yerr: JAXArray | NDArray,
        kernel: tk.Kernel | quasisep.Quasisep,
        *,
        mean_func: Callable | None = None,
        amp_scale_func: Callable | None = None,
        **kwargs,
    ) -> None:
        # convert once so plain sequences are accepted everywhere below
        t = jnp.asarray(t)
        inds = jnp.argsort(t)
        # every observation belongs to band 0
        X = (t[inds], jnp.zeros_like(t, dtype=int))

        # a single band has no inter-band lag; force the flag instead of passing
        # it as a second keyword, which would clash with a `has_lag` in kwargs
        kwargs["has_lag"] = False

        super().__init__(
            X,
            jnp.asarray(y)[inds],
            jnp.asarray(yerr)[inds],
            kernel,
            1,  # n_band
            mean_func=mean_func,
            amp_scale_func=amp_scale_func,
            **kwargs,
        )

    def _default_amp_scale_func(self, params: dict[str, JAXArray]) -> JAXArray:
        """Return a single log amplitude scale of 0.0, independent of `params`."""
        return jnp.array([0.0])

    def _lag_transform_fast(
        self,
        has_lag: bool,
        params: dict[str, JAXArray],
    ) -> tuple[tuple[JAXArray, JAXArray], JAXArray]:
        """Return the stored time axis unchanged (no lag for a single band).

        Args:
            has_lag (bool): Ignored; kept for signature compatibility with
                the parent class.
            params (dict): Ignored; kept for signature compatibility with the
                parent class.

        Returns:
            tuple(tuple(JAXArray, JAXArray), JAXArray): The stored times and
            bands, and the identity index array over the stored times.
        """
        return self.X, jnp.arange(self.X[0].size)

    def pred(
        self, params: dict[str, JAXArray], t: JAXArray | NDArray
    ) -> tuple[JAXArray, JAXArray]:
        """Make conditional GP prediction.

        Args:
            params (dict[str, JAXArray]): A dictionary containing model
                parameters.
            t (JAXArray): The time information for creating the conditional GP
                prediction.

        Returns:
            tuple[JAXArray, JAXArray]: A tuple of the mean GP prediction and its
            uncertainty (square root of the predicted variance).
        """
        t = jnp.asarray(t)

        # build gp, cond; prediction points are all in band 0
        gp, inds = self._build_gp(params)
        _, cond = gp.condition(self.y[inds], (t, jnp.zeros_like(t, dtype=int)))
        return cond.loc, jnp.sqrt(cond.variance)