eztaox.models
=============

.. py:module:: eztaox.models

.. autoapi-nested-parse::

   Light curve models.

   A module of light curve models, which provide the interface for modeling
   uni-/multi-band light curves using Gaussian Processes (GPs).


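MultiVarModel expects ``X`` as a tuple of time stamps and band indices, with ``y`` and ``yerr`` as flat arrays aligned with it. The sketch below shows one way to pack separate per-band light curves into that layout using only NumPy. The helper name ``pack_multiband`` is hypothetical and not part of eztaox. Band indices ``0, 1, ...`` are assigned in list order, which is an assumed convention. Sorting by time is a precaution, because the quasiseparable solvers of DFM+17 operate on time-ordered data. Whether eztaox requires this sort, or performs it itself, is not stated here.

```python
import numpy as np


def pack_multiband(times, values, errors):
    """Pack per-band light curves into ((t, band), y, yerr) arrays.

    Each argument is a list with one 1-D array per band; a point's band
    index is the position of its band in the list (hypothetical convention).
    """
    t = np.concatenate(times)
    band = np.concatenate(
        [np.full(len(tb), i, dtype=int) for i, tb in enumerate(times)]
    )
    y = np.concatenate(values)
    yerr = np.concatenate(errors)

    # Order all points by time, keeping the four arrays aligned.
    order = np.argsort(t, kind="stable")
    return (t[order], band[order]), y[order], yerr[order]


# Two hypothetical bands with 3 and 2 observations.
X, y, yerr = pack_multiband(
    times=[np.array([0.0, 1.0, 2.5]), np.array([0.5, 3.0])],
    values=[np.array([0.1, -0.2, 0.05]), np.array([0.3, 0.0])],
    errors=[np.full(3, 0.05), np.full(2, 0.08)],
)
print(X[1])  # [0 1 0 0 1]
```

The resulting ``X``, ``y`` and ``yerr`` match the first three arguments of MultiVarModel. You would pass them together with a base kernel and ``n_band=2``.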

Classes
-------

.. autoapisummary::

   eztaox.models.MultiVarModel
   eztaox.models.UniVarModel


Module Contents
---------------

.. py:class:: MultiVarModel(X: tuple[tinygp.helpers.JAXArray | numpy.typing.NDArray, tinygp.helpers.JAXArray | numpy.typing.NDArray], y: tinygp.helpers.JAXArray | numpy.typing.NDArray, yerr: tinygp.helpers.JAXArray | numpy.typing.NDArray, base_kernel: tinygp.kernels.Kernel | eztaox.kernels.quasisep.Quasisep, n_band: int, *, multiband_kernel: tinygp.kernels.Kernel | tinygp.kernels.quasisep.Wrapper | None = None, mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, lag_func: collections.abc.Callable | None = None, **kwargs)

   Bases: :py:obj:`equinox.Module`


   An interface for modeling multivariate/multi-band time series using GPs.

   This interface only accepts GP kernels that can be evaluated with the
   scalable method of `DFM+17 <https://arxiv.org/abs/1703.09710>`_. It also
   supports fitting a parameterized mean function for the time series, an
   additional variance (jitter) term added to the measurement uncertainties, and
   time delays between the individual univariate/single-band time series.

   :param X: Input data as a tuple of time stamps and band indices.
   :type X: tuple[JAXArray|NDArray, JAXArray|NDArray]
   :param y: Observed data values.
   :type y: JAXArray|NDArray
   :param yerr: Observational uncertainties.
   :type yerr: JAXArray|NDArray
   :param base_kernel: A GP kernel from the eztaox.kernels.quasisep module.
   :type base_kernel: Quasisep
   :param n_band: Number of bands in the input light curve.
   :type n_band: int
   :param multiband_kernel: A multiband kernel specifying the cross-band
                            covariance. If None (the default),
                            kernels.quasisep.MultibandLowRank is used.
   :type multiband_kernel: Quasisep, optional
   :param mean_func: A callable mean function for the GP, defaults to
                     None.
   :type mean_func: Callable, optional
   :param amp_scale_func: A callable amplitude scaling function,
                          defaults to None.
   :type amp_scale_func: Callable, optional
   :param lag_func: A callable function for time delays between bands,
                    defaults to None.
   :type lag_func: Callable, optional
   :param \*\*kwargs: Additional keyword arguments.

                      - `zero_mean` (bool): If True, assumes a zero-mean GP. Defaults to True.
                      - `has_jitter` (bool): If True, assumes the input observational errors
                        are underestimated and fits an additional jitter term. Defaults to False.
                      - `has_lag` (bool): If True, fits for time delays between the time series
                        in different bands. Defaults to False.


   .. py:method:: get_mean(zero_mean: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Return the mean of the GP.



   .. py:method:: get_amp_scale(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Return the amplitude of the GP in each individual band.



   .. py:method:: lag_transform(has_lag: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

      Shift the time axis by the lag in each band.

      :param has_lag: Whether to apply time delays (lags) between bands.
      :type has_lag: bool
      :param params: Model parameters passed to the lag function (the callable
                     for time delays between bands).
      :type params: dict[str, JAXArray]
      :param X: Time stamps and band indices.
      :type X: JAXArray

      :returns: The lag-shifted times and band indices, and the indices of the
                new times.
      :rtype: tuple[tuple[JAXArray, JAXArray], JAXArray]



   .. py:method:: log_prior(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Calculate the log prior of the input parameters.

      :param params: Model parameters.
      :type params: dict[str, JAXArray]

      :returns: Log prior of the input parameters.
      :rtype: JAXArray



   .. py:method:: log_prob(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Calculate the log probability of the input parameters.

      :param params: Model parameters.
      :type params: dict[str, JAXArray]

      :returns: Log probability of the input parameters.
      :rtype: JAXArray



   .. py:method:: aic(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Calculate the Akaike Information Criterion (AIC) for the model.

      :param params: Model parameters.
      :type params: dict[str, JAXArray]

      :returns: AIC value.
      :rtype: JAXArray



   .. py:method:: bic(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Calculate the Bayesian Information Criterion (BIC) for the model.

      :param params: Model parameters.
      :type params: dict[str, JAXArray]

      :returns: BIC value.
      :rtype: JAXArray



   .. py:method:: sample(params: dict[str, tinygp.helpers.JAXArray]) -> None

      Define the model for numpyro, enabling MCMC sampling of the parameters.

      :param params: Model parameters.
      :type params: dict[str, JAXArray]



   .. py:method:: pred(params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Make a conditional GP prediction.

      :param params: A dictionary containing model parameters.
      :type params: dict[str, JAXArray]
      :param X: The time and band information for creating the conditional GP
                prediction.
      :type X: JAXArray

      :returns: A tuple of the mean GP prediction and its
                uncertainty (square root of the predicted variance).
      :rtype: tuple[JAXArray, JAXArray]

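MultiVarModel.lag_transform shifts the time axis by the lag in each band. It returns the modified times and bands together with indices of the new times. The NumPy sketch below illustrates that idea. Several details are assumptions made for illustration:

- The per-band ``lags`` array is hypothetical, with band 0 as the zero-lag reference.
- The lag is subtracted from the times.
- The shifted points are re-sorted by time.

eztaox's own ``lag_func`` parameterization may differ.

```python
import numpy as np


def shift_by_lag(t, band, lags):
    """Shift each band's times by its lag and re-sort by the new times.

    lags[b] is the assumed delay of band b relative to band 0 (lags[0] = 0).
    Returns the shifted (t, band) pair and the indices that sort them.
    """
    t_shifted = t - lags[band]  # assumed convention: subtract the delay
    idx = np.argsort(t_shifted, kind="stable")
    return (t_shifted[idx], band[idx]), idx


t = np.array([0.0, 1.0, 2.0, 0.5, 1.5])
band = np.array([0, 0, 0, 1, 1])
lags = np.array([0.0, 0.8])  # hypothetical: band 1 lags band 0 by 0.8

(t_new, band_new), idx = shift_by_lag(t, band, lags)
print(idx)  # [3 0 4 1 2]
```

Apply the same ``idx`` to ``y`` and ``yerr`` (e.g. ``y[idx]``) so the observations stay aligned with the shifted times.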

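In MultiVarModel, ``get_amp_scale`` returns a GP amplitude per band, and the multiband kernel sets the cross-band covariance. One simple model consistent with that description scales a single shared kernel by per-band amplitudes. This gives a rank-1 cross-band structure, :math:`K_{ij} = a_{b_i} a_{b_j} k(t_i - t_j)`. The following pieces of the sketch are assumptions made for illustration:

- the exponential (damped random walk) kernel;
- the hypothetical ``amps`` array;
- the dense NumPy construction.

This is not claimed to be how MultibandLowRank is defined.

```python
import numpy as np


def exp_kernel(dt, sigma, tau):
    """Exponential (damped random walk) covariance sigma^2 * exp(-|dt| / tau)."""
    return sigma**2 * np.exp(-np.abs(dt) / tau)


def multiband_cov(t, band, amps, sigma, tau):
    """Dense cross-band covariance K_ij = a[b_i] * a[b_j] * k(t_i - t_j)."""
    a = amps[band]  # per-point amplitude, looked up from the point's band
    dt = t[:, None] - t[None, :]
    return np.outer(a, a) * exp_kernel(dt, sigma, tau)


t = np.array([0.0, 0.0, 1.0])
band = np.array([0, 1, 0])
amps = np.array([1.0, 0.5])  # hypothetical: band 0 is the reference
K = multiband_cov(t, band, amps, sigma=1.0, tau=1.0)
print(K[0, 1])  # 0.5
```

Because every band shares one latent process, simultaneous points in different bands stay correlated. The correlation is scaled by their relative amplitudes.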

.. py:class:: UniVarModel(t: tinygp.helpers.JAXArray | numpy.typing.NDArray, y: tinygp.helpers.JAXArray | numpy.typing.NDArray, yerr: tinygp.helpers.JAXArray | numpy.typing.NDArray, kernel: tinygp.kernels.Kernel | eztaox.kernels.quasisep.Quasisep, *, mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, **kwargs)

   Bases: :py:obj:`MultiVarModel`


   A subclass of MultiVarModel for modeling univariate/single-band time series data.

   :param t: Time stamps of the input light curve.
   :type t: JAXArray|NDArray
   :param y: Observed data values at the corresponding time stamps.
   :type y: JAXArray|NDArray
   :param yerr: Observational uncertainties.
   :type yerr: JAXArray|NDArray
   :param kernel: A GP kernel from the eztaox.kernels.quasisep module.
   :type kernel: Quasisep
   :param mean_func: A callable mean function for the GP, defaults to
                     None.
   :type mean_func: Callable, optional
   :param amp_scale_func: A callable amplitude scaling function,
                          defaults to None.
   :type amp_scale_func: Callable, optional
   :param \*\*kwargs: Additional keyword arguments.

                      - `zero_mean` (bool): If True, assumes a zero-mean GP. Defaults to True.
                      - `has_jitter` (bool): If True, assumes the input observational errors
                        are underestimated and fits an additional jitter term. Defaults to False.


   .. py:method:: pred(params, t) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Make a conditional GP prediction.

      :param params: A dictionary containing model parameters.
      :type params: dict[str, JAXArray]
      :param t: The time information for creating the conditional GP
                prediction.
      :type t: JAXArray

      :returns: A tuple of the mean GP prediction and its
                uncertainty (square root of the predicted variance).
      :rtype: tuple[JAXArray, JAXArray]

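With ``has_jitter=True``, both models treat the reported uncertainties as underestimated and add an extra variance term. The sketch below shows how such a term enters a Gaussian-process likelihood. It assumes the jitter is added in quadrature, so the diagonal variance is ``yerr**2 + jitter**2``, and it uses an exponential kernel. It builds the dense covariance and factorizes it with a Cholesky decomposition, which costs O(N^3). The DFM+17 method used by eztaox avoids that cost, so this code only illustrates the algebra. The jitter parameterization in eztaox (for example, whether it is fit in log space) is not stated here.

```python
import numpy as np


def gp_log_likelihood(t, y, yerr, sigma, tau, jitter=0.0, mean=0.0):
    """Log-likelihood of y under a GP with constant mean (dense, O(N^3)).

    The noise variance on the diagonal is yerr**2 + jitter**2, i.e. the
    jitter is assumed to add in quadrature to the reported uncertainties.
    """
    r = y - mean
    K = sigma**2 * np.exp(-np.abs(t[:, None] - t[None, :]) / tau)
    K[np.diag_indices_from(K)] += yerr**2 + jitter**2
    L = np.linalg.cholesky(K)  # K = L @ L.T
    alpha = np.linalg.solve(L, r)  # L^{-1} r, so alpha @ alpha = r^T K^{-1} r
    return (
        -0.5 * alpha @ alpha
        - np.sum(np.log(np.diag(L)))  # equals 0.5 * log det K
        - 0.5 * len(y) * np.log(2.0 * np.pi)
    )


t = np.array([0.0, 0.7, 1.5, 3.0])
y = np.array([0.2, 0.1, -0.4, 0.3])
yerr = np.full(4, 0.05)
ll_plain = gp_log_likelihood(t, y, yerr, sigma=0.5, tau=2.0)
ll_jitter = gp_log_likelihood(t, y, yerr, sigma=0.5, tau=2.0, jitter=0.1)
```

Fitting ``jitter`` alongside the kernel parameters lets the data decide how much extra scatter the reported errors are missing.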

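Both ``pred`` methods return a conditional GP mean and its uncertainty, the square root of the predicted variance. The standard GP conditioning equations are :math:`\mu_* = K_* K^{-1} y` and :math:`\sigma_*^2 = k(t_*, t_*) - \mathrm{diag}(K_* K^{-1} K_*^\top)`. They can be sketched densely in NumPy as below. The sketch assumes a zero-mean GP with an exponential kernel and no jitter. It illustrates the equations and does not reproduce eztaox's own implementation.

```python
import numpy as np


def exp_kernel(t1, t2, sigma, tau):
    """Exponential (damped random walk) covariance between two time arrays."""
    return sigma**2 * np.exp(-np.abs(t1[:, None] - t2[None, :]) / tau)


def gp_predict(t, y, yerr, t_new, sigma, tau):
    """Dense conditional GP prediction for a zero-mean GP.

    Returns the predictive mean at t_new and the square root of the
    predictive variance.
    """
    K = exp_kernel(t, t, sigma, tau) + np.diag(yerr**2)
    K_s = exp_kernel(t_new, t, sigma, tau)  # cross-covariance, shape (M, N)
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # K^{-1} y
    mu = K_s @ alpha
    v = np.linalg.solve(L, K_s.T)  # L^{-1} K_s^T, shape (N, M)
    var = sigma**2 - np.sum(v**2, axis=0)  # prior variance k(t*, t*) = sigma^2
    return mu, np.sqrt(np.clip(var, 0.0, None))  # clip tiny negative round-off


t = np.array([0.0, 1.0, 2.0])
y = np.array([1.0, 0.5, -0.3])
yerr = np.full(3, 1e-3)
mu, std = gp_predict(t, y, yerr, np.array([1.0, 1000.0]), sigma=1.0, tau=1.0)
```

The two limits in this example are worth noting. Near a precise observation (``t = 1.0``), the prediction follows the data with a small uncertainty. Far from all data (``t = 1000``), it relaxes to the prior mean of zero with standard deviation ``sigma``.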

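MultiVarModel.aic and MultiVarModel.bic, which UniVarModel inherits, compute standard information criteria. Their textbook definitions are :math:`\mathrm{AIC} = 2k - 2\ln\hat{L}` and :math:`\mathrm{BIC} = k\ln n - 2\ln\hat{L}`. Here :math:`k` is the number of free parameters, :math:`n` the number of data points and :math:`\hat{L}` the maximized likelihood. Lower values indicate a preferred model. The helper functions below sketch these definitions only. How eztaox counts :math:`k` (for example, whether jitter and lag parameters are included), and which likelihood it evaluates, is not stated here.

```python
import numpy as np


def aic(max_log_likelihood, n_params):
    """Akaike Information Criterion: 2k - 2 ln L_max."""
    return 2.0 * n_params - 2.0 * max_log_likelihood


def bic(max_log_likelihood, n_params, n_data):
    """Bayesian Information Criterion: k ln(n) - 2 ln L_max."""
    return n_params * np.log(n_data) - 2.0 * max_log_likelihood


# Hypothetical fit: maximized log-likelihood -10.0 with 3 parameters, 100 points.
print(aic(-10.0, 3))  # 26.0
print(bic(-10.0, 3, 100))
```

The BIC penalty :math:`k\ln n` exceeds the AIC penalty :math:`2k` once :math:`n > e^2 \approx 7.4`. For typical light-curve sizes, BIC therefore favors models with fewer parameters more strongly.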
