eztaox.simulator
================

.. py:module:: eztaox.simulator

.. autoapi-nested-parse::

   Simulator module for multivariate and univariate Gaussian processes.



Classes
-------

.. autoapisummary::

   eztaox.simulator.MultiVarSim
   eztaox.simulator.UniVarSim


Module Contents
---------------

.. py:class:: MultiVarSim(base_kernel: tinygp.kernels.Kernel | tinygp.kernels.quasisep.Quasisep, min_dt: float, max_dt: float, n_band: int, init_params: dict[str, tinygp.helpers.JAXArray], *, multiband_kernel: tinygp.kernels.Kernel | tinygp.kernels.quasisep.Wrapper | None = None, mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, lag_func: collections.abc.Callable | None = None, **kwargs)

   Bases: :py:obj:`equinox.Module`


   An interface for simulating multivariate (multi-band) time series using Gaussian processes (GPs).

   This interface only accepts GP kernels that can be evaluated with the
   scalable method of `DFM+17 <https://arxiv.org/abs/1703.09710>`_. It lets you
   specify a parameterized mean function for the time series, the cross-band
   covariance, and the time delays between the univariate (single-band) time series.

   :param base_kernel: A GP kernel from the ``tinygp.kernels.quasisep`` module.
   :type base_kernel: Quasisep
   :param min_dt: Minimum time step for the simulation.
   :type min_dt: float
   :param max_dt: Maximum time step (temporal baseline) for the simulation.
   :type max_dt: float
   :param n_band: Number of bands in the simulated light curve.
   :type n_band: int
   :param init_params: Initial parameters for the GP.
   :type init_params: dict[str, JAXArray]
   :param multiband_kernel: A multiband kernel specifying the
                            cross-band covariance, defaults to ``kernels.quasisep.MultibandLowRank``.
   :type multiband_kernel: Quasisep, optional
   :param mean_func: A callable mean function for the GP, defaults to
                     None.
   :type mean_func: Callable, optional
   :param amp_scale_func: A callable amplitude scaling function,
                          defaults to None.
   :type amp_scale_func: Callable, optional
   :param lag_func: A callable function for time delays between bands,
                    defaults to None.
   :type lag_func: Callable, optional
   :param \*\*kwargs: Additional keyword arguments.

                      - ``zero_mean`` (bool): If True, assumes a zero-mean GP. Defaults to True.
                      - ``has_lag`` (bool): If True, models time delays between the time series
                        in different bands. Defaults to False.
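
   As a mental model for the cross-band covariance, consider the simplest low-rank case: a rank-one
   structure in which every band is a scaled copy of one latent process, so that
   :math:`k((t_i, b_i), (t_j, b_j)) = a_{b_i} a_{b_j} k(t_i - t_j)`. The NumPy sketch below builds
   that covariance densely for an exponential (damped-random-walk) kernel and draws one sample. It is
   an illustration under stated assumptions, not the eztaox implementation: ``rank_one_multiband_cov``,
   ``amps``, ``tau``, and ``sigma`` are hypothetical names, and a dense Cholesky factorization costs
   :math:`O(N^3)` instead of the linear cost of the DFM+17 method.

   .. code-block:: python

      import numpy as np

      def rank_one_multiband_cov(t, band, amps, tau, sigma):
          """Hypothetical helper: dense covariance of a rank-one multiband GP.

          Each band b is modeled as amps[b] times one shared latent process
          with an exponential (damped random walk) kernel.
          """
          t = np.asarray(t, dtype=float)
          band = np.asarray(band, dtype=int)
          amps = np.asarray(amps, dtype=float)
          # Exponential kernel evaluated on all pairwise time separations.
          k_time = sigma**2 * np.exp(-np.abs(t[:, None] - t[None, :]) / tau)
          # Cross-band scaling a_{b_i} * a_{b_j}.
          scale = amps[band][:, None] * amps[band][None, :]
          return scale * k_time

      # Two bands observed at the same three times.
      t = np.array([0.0, 0.5, 1.0, 0.0, 0.5, 1.0])
      band = np.array([0, 0, 0, 1, 1, 1])
      cov = rank_one_multiband_cov(t, band, amps=[1.0, 0.5], tau=2.0, sigma=1.0)

      # A small jitter keeps the Cholesky factorization stable: with a rank-one
      # cross-band structure, bands observed at identical times make the matrix singular.
      rng = np.random.default_rng(0)
      chol = np.linalg.cholesky(cov + 1e-9 * np.eye(t.size))
      y = chol @ rng.standard_normal(t.size)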


   .. py:method:: full(key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

      Simulate a multivariate GP time series with uniform time sampling.

      :param key: Random number generator key.
      :type key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the
                form of (time, band) and the simulated light curve values.
      :rtype: tuple[tuple[JAXArray, JAXArray], JAXArray]
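
      The ``(time, band)`` tuple labels every simulated point with its timestamp and an integer band
      index. A minimal sketch of a uniform multiband grid, assuming (this page does not confirm it)
      that the grid runs from 0 to ``max_dt`` in steps of ``min_dt`` and that all bands share the same
      timestamps; ``uniform_multiband_grid`` is a hypothetical helper:

      .. code-block:: python

         import numpy as np

         def uniform_multiband_grid(min_dt, max_dt, n_band):
             """Hypothetical helper: one uniform time grid repeated for each band."""
             t_grid = np.arange(0.0, max_dt, min_dt)
             # Stack one copy of the grid per band and label each copy with its band index.
             time = np.tile(t_grid, n_band)
             band = np.repeat(np.arange(n_band), t_grid.size)
             return time, band

         time, band = uniform_multiband_grid(min_dt=0.5, max_dt=2.0, n_band=2)
         # time -> [0.0, 0.5, 1.0, 1.5, 0.0, 0.5, 1.0, 1.5]
         # band -> [0, 0, 0, 0, 1, 1, 1, 1]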



   .. py:method:: random(nRand: int, lc_key: jax.random.PRNGKey, random_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a multivariate GP time series with random time sampling.

      :param nRand: Number of data points in the simulated time series.
      :type nRand: int
      :param lc_key: Random number generator key for simulating a
                     full light curve with uniform time sampling.
      :type lc_key: jax.random.PRNGKey
      :param random_key: Random number generator key for selecting
                         random data points from the full light curve.
      :type random_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the
                form of (time, band) and the simulated light curve values.
      :rtype: tuple[tuple[JAXArray, JAXArray], JAXArray]
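
      The two keys separate the two sources of randomness: ``lc_key`` determines the full light curve
      and ``random_key`` determines which of its points are kept, so, for example, one light curve can
      be resampled with different cadences. The NumPy sketch below shows the selection step only. It
      assumes points are drawn without replacement and returned in time order, and it uses integer
      seeds for ``numpy.random.default_rng`` in place of JAX keys; ``random_subsample`` is a
      hypothetical helper.

      .. code-block:: python

         import numpy as np

         def random_subsample(time, band, y, n_rand, seed):
             """Hypothetical helper: keep n_rand distinct points, sorted by time."""
             rng = np.random.default_rng(seed)
             idx = rng.choice(time.size, size=n_rand, replace=False)
             # Reorder the selected indices so the output is in time order.
             idx = idx[np.argsort(time[idx], kind="stable")]
             return (time[idx], band[idx]), y[idx]

         # Stand-in "full" light curve: two bands on a shared grid of 10 times.
         time = np.tile(np.arange(0.0, 10.0, 1.0), 2)
         band = np.repeat(np.arange(2), 10)
         y = np.random.default_rng(1).standard_normal(time.size)

         (t_sel, b_sel), y_sel = random_subsample(time, band, y, n_rand=5, seed=2)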



   .. py:method:: fixed_input(sim_X: tuple[tinygp.helpers.JAXArray | numpy.typing.NDArray, tinygp.helpers.JAXArray | numpy.typing.NDArray], lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a multivariate GP time series at fixed input times and band labels.

      :param sim_X: Input time and band.
      :type sim_X: tuple[JAXArray|NDArray, JAXArray|NDArray]
      :param lc_key: Random number generator key for simulating a
                     full light curve with uniform time sampling.
      :type lc_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the
                form of (time, band) and the simulated light curve values.
      :rtype: tuple[tuple[JAXArray, JAXArray], JAXArray]
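
      According to the description of :py:meth:`fixed_input_fast`, this method simulates a full,
      uniformly sampled light curve and then selects the points that match the input times. A
      nearest-neighbour lookup is one plausible matching rule; it is an assumption here, since the
      actual rule is not documented on this page. The sketch handles a single band (for multiple
      bands it would be applied to each band's inputs separately), and ``nearest_grid_index`` is a
      hypothetical helper.

      .. code-block:: python

         import numpy as np

         def nearest_grid_index(t_grid, t_in):
             """Hypothetical helper: index of the grid point closest to each input time.

             Assumes t_grid is sorted in ascending order.
             """
             t_in = np.asarray(t_in, dtype=float)
             idx = np.searchsorted(t_grid, t_in)
             # Clip so every input has both a left and a right neighbour.
             idx = np.clip(idx, 1, t_grid.size - 1)
             left, right = t_grid[idx - 1], t_grid[idx]
             # Step back one index wherever the left neighbour is strictly closer.
             idx = idx - ((t_in - left) < (right - t_in)).astype(int)
             return idx

         t_grid = np.arange(0.0, 10.0, 0.5)  # stand-in for the uniform simulation grid
         t_in = np.array([0.1, 3.3, 7.76, 9.9])
         matched = t_grid[nearest_grid_index(t_grid, t_in)]
         # matched -> [0.0, 3.5, 8.0, 9.5]

      Under this rule, an input inside the grid is matched to a grid time at most half a grid step
      away, while an input past the end of the grid snaps to its last point.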



   .. py:method:: fixed_input_fast(sim_X: tuple[tinygp.helpers.JAXArray | numpy.typing.NDArray, tinygp.helpers.JAXArray | numpy.typing.NDArray], lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

      Simulate a multivariate GP time series at fixed input times and band labels.

      This method is faster than :py:meth:`fixed_input` because it simulates the GP
      only at the input times, rather than simulating a full light curve and selecting
      the points that match the input times.

      :param sim_X: Input time and band.
      :type sim_X: tuple[JAXArray|NDArray, JAXArray|NDArray]
      :param lc_key: Random number generator key for simulating a
                     full light curve with uniform time sampling.
      :type lc_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the
                form of (time, band) and the simulated light curve values.
      :rtype: tuple[tuple[JAXArray, JAXArray], JAXArray]
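
      Evaluating the GP only at the requested inputs needs no grid when the kernel has a state-space
      form. As an illustration, assuming an exponential (damped-random-walk) kernel and a rank-one
      cross-band model with hypothetical per-band amplitudes ``amps``, the sampler below sorts the
      inputs by time, applies the exact Ornstein-Uhlenbeck transition between consecutive times, and
      scales each point by its band amplitude. ``sample_drw_multiband`` is a hypothetical helper and
      an integer seed stands in for ``lc_key``; this is a sketch, not the eztaox code path.

      .. code-block:: python

         import numpy as np

         def sample_drw_multiband(t, band, amps, tau, sigma, seed):
             """Hypothetical helper: exact draw of a rank-one multiband DRW GP at (t, band)."""
             t = np.asarray(t, dtype=float)
             band = np.asarray(band, dtype=int)
             rng = np.random.default_rng(seed)
             order = np.argsort(t, kind="stable")
             t_sorted = t[order]
             latent = np.empty_like(t_sorted)
             # Start from the stationary distribution N(0, sigma^2).
             latent[0] = sigma * rng.standard_normal()
             for i in range(1, t_sorted.size):
                 phi = np.exp(-(t_sorted[i] - t_sorted[i - 1]) / tau)
                 # OU transition: keep a fraction phi of the previous state, add fresh noise.
                 latent[i] = phi * latent[i - 1] + sigma * np.sqrt(1.0 - phi**2) * rng.standard_normal()
             y = np.empty_like(latent)
             y[order] = latent  # undo the sort so y lines up with the original inputs
             return np.asarray(amps, dtype=float)[band] * y

         t = np.array([0.3, 0.3, 1.7, 5.0])
         band = np.array([0, 1, 0, 1])
         y = sample_drw_multiband(t, band, amps=[1.0, 0.5], tau=10.0, sigma=1.0, seed=0)

      Points observed in different bands at the same time share one latent value, which is what a
      rank-one cross-band model implies.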



   .. py:method:: get_mean(zero_mean: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tinygp.helpers.JAXArray

      Return the mean of the GP.
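
      A hypothetical sketch of the ``zero_mean`` dispatch, using a per-band constant mean. The
      parameter key ``"mean"``, the helper names, and the ``mean_func(params, X)`` calling convention
      are assumptions for illustration, not the documented eztaox interface.

      .. code-block:: python

         import numpy as np

         def per_band_constant_mean(params, X):
             """Hypothetical mean function: one constant offset per band."""
             _, band = X
             return np.asarray(params["mean"])[band]

         def get_mean_sketch(zero_mean, params, X, mean_func=per_band_constant_mean):
             """Return zeros for a zero-mean GP, otherwise evaluate mean_func."""
             time, _ = X
             if zero_mean:
                 return np.zeros_like(time, dtype=float)
             return mean_func(params, X)

         X = (np.array([0.0, 1.0, 0.0, 1.0]), np.array([0, 0, 1, 1]))
         params = {"mean": np.array([18.0, 18.5])}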



   .. py:method:: get_amp_scale(params: dict[str, tinygp.helpers.JAXArray]) -> tinygp.helpers.JAXArray

      Return the amplitude scale of the GP in each individual band.
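
      One common way to parameterize per-band amplitudes, shown here as an assumption rather than
      the eztaox convention, is to fix the first band as the reference (scale 1) and fit log-amplitude
      ratios for the other bands; working in log space keeps every scale positive. The key
      ``"log_amp_scale"`` and the helper name are hypothetical.

      .. code-block:: python

         import numpy as np

         def amp_scale_sketch(params):
             """Hypothetical helper: band 0 has scale 1, other bands use exp(log ratio)."""
             log_ratios = np.asarray(params["log_amp_scale"], dtype=float)
             return np.concatenate([np.ones(1), np.exp(log_ratios)])

         scales = amp_scale_sketch({"log_amp_scale": [np.log(0.5), np.log(2.0)]})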



   .. py:method:: lag_transform(has_lag: bool, params: dict[str, tinygp.helpers.JAXArray], X: tinygp.helpers.JAXArray) -> tuple[tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

      Shift the time axis by the lag in each band.
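
      A sketch of the time shift under two assumptions: band 0 is the zero-lag reference, and a
      positive lag means band ``b`` responds later, so its times are mapped to ``t - lag_b`` on the
      reference time axis. ``lag_shift_sketch`` is a hypothetical helper.

      .. code-block:: python

         import numpy as np

         def lag_shift_sketch(lags, X):
             """Hypothetical helper: shift each point's time by its band's lag."""
             time, band = X
             # Prepend a zero lag for the reference band.
             all_lags = np.concatenate([np.zeros(1), np.asarray(lags, dtype=float)])
             return time - all_lags[band], band

         X = (np.array([10.0, 10.0, 10.0]), np.array([0, 1, 2]))
         t_shifted, _ = lag_shift_sketch([2.0, 5.0], X)
         # t_shifted -> [10.0, 8.0, 5.0]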



.. py:class:: UniVarSim(base_kernel: tinygp.kernels.Kernel | tinygp.kernels.quasisep.Quasisep, min_dt: float, max_dt: float, init_params: dict[str, tinygp.helpers.JAXArray], *, mean_func: collections.abc.Callable | None = None, amp_scale_func: collections.abc.Callable | None = None, **kwargs)

   Bases: :py:obj:`MultiVarSim`


   An interface for simulating univariate/single-band GP time series.

   :param base_kernel: A GP kernel from the ``tinygp.kernels.quasisep`` module.
   :type base_kernel: Quasisep
   :param min_dt: Minimum time step for the simulation.
   :type min_dt: float
   :param max_dt: Maximum time step (temporal baseline) for the simulation.
   :type max_dt: float
   :param init_params: Initial parameters for the GP.
   :type init_params: dict[str, JAXArray]
   :param mean_func: A callable mean function for the GP, defaults to
                     None.
   :type mean_func: Callable, optional
   :param amp_scale_func: A callable amplitude scaling function,
                          defaults to None.
   :type amp_scale_func: Callable, optional
   :param \*\*kwargs: Additional keyword arguments.

                      - ``zero_mean`` (bool): If True, assumes a zero-mean GP. Defaults to True.


   .. py:method:: full(key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a univariate GP time series with uniform time sampling.

      :param key: Random number generator key.
      :type key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the form of (time,
                light curve values).
      :rtype: tuple[JAXArray, JAXArray]
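
      On a uniform grid, a GP with an exponential (damped-random-walk) kernel
      :math:`k(\Delta t) = \sigma^2 e^{-|\Delta t|/\tau}` is exactly an AR(1) process with coefficient
      :math:`\phi = e^{-\Delta t/\tau}`, so a full light curve can be drawn in linear time. The sketch
      below assumes that kernel and a grid from 0 to ``max_dt`` in steps of ``min_dt`` (an assumption
      about how these arguments are used); ``simulate_drw_uniform``, ``tau``, and ``sigma`` are
      hypothetical names, and an integer seed stands in for ``key``.

      .. code-block:: python

         import numpy as np

         def simulate_drw_uniform(min_dt, max_dt, tau, sigma, seed):
             """Hypothetical helper: DRW light curve on a uniform grid via its exact AR(1) form."""
             rng = np.random.default_rng(seed)
             t = np.arange(0.0, max_dt, min_dt)
             phi = np.exp(-min_dt / tau)                 # constant AR(1) coefficient
             innov_sd = sigma * np.sqrt(1.0 - phi**2)    # keeps the variance at sigma^2
             y = np.empty_like(t)
             y[0] = sigma * rng.standard_normal()        # stationary start
             for i in range(1, t.size):
                 y[i] = phi * y[i - 1] + innov_sd * rng.standard_normal()
             return t, y

         t, y = simulate_drw_uniform(min_dt=0.1, max_dt=2000.0, tau=5.0, sigma=0.2, seed=0)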



   .. py:method:: random(nRand: int, lc_key: jax.random.PRNGKey, random_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a univariate GP time series with random time sampling.

      :param nRand: Number of data points in the simulated time series.
      :type nRand: int
      :param lc_key: Random number generator key for simulating a
                     full light curve with uniform time sampling.
      :type lc_key: jax.random.PRNGKey
      :param random_key: Random number generator key for selecting
                         random data points from the full light curve.
      :type random_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the form of (time,
                light curve values).
      :rtype: tuple[JAXArray, JAXArray]



   .. py:method:: fixed_input(sim_t: tinygp.helpers.JAXArray | numpy.typing.NDArray, lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a univariate GP time series with fixed input time.

      :param sim_t: Input time for the simulation.
      :type sim_t: JAXArray | NDArray
      :param lc_key: Random number generator key for simulating a
                     full light curve with uniform time sampling.
      :type lc_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the form of (time,
                light curve values).
      :rtype: tuple[JAXArray, JAXArray]



   .. py:method:: fixed_input_fast(sim_t: tinygp.helpers.JAXArray | numpy.typing.NDArray, lc_key: jax.random.PRNGKey, params: dict[str, tinygp.helpers.JAXArray] | None = None) -> tuple[tinygp.helpers.JAXArray, tinygp.helpers.JAXArray]

      Simulate a univariate GP time series with fixed input time.

      This method is faster than :py:meth:`fixed_input` because it simulates the GP
      only at the input times, rather than simulating a full light curve and selecting
      the points that match the input times.

      :param sim_t: Input time for the simulation.
      :type sim_t: JAXArray | NDArray
      :param lc_key: Random number generator key for simulating the
                     light curve at the input times.
      :type lc_key: jax.random.PRNGKey
      :param params: Light curve model parameters. Defaults to None,
                     in which case ``init_params`` is used.
      :type params: dict[str, JAXArray] | None, optional

      :returns: Simulated time series in the form of (time,
                light curve values).
      :rtype: tuple[JAXArray, JAXArray]



