eztaox.fitter
=============

.. py:module:: eztaox.fitter

.. autoapi-nested-parse::

   This module contains functions that fit a model to data.



Attributes
----------

.. autoapisummary::

   eztaox.fitter.DEFAULT_ADAM_OPTIMIZER


Functions
---------

.. autoapisummary::

   eztaox.fitter.random_search
   eztaox.fitter.simple_optimizer


Module Contents
---------------

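.. py:data:: DEFAULT_ADAM_OPTIMIZER

   Default optimizer for local optimization in :func:`random_search` and
   :func:`simple_optimizer`, equivalent to ``optax.adam(1e-2)``.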

.. py:function:: random_search(model: eztaox.models.UniVarModel | eztaox.models.MultiVarModel, init_sampler: collections.abc.Callable, prng_key: jax.random.PRNGKey, n_sample: int, n_best: int, *, batch_size: int = 1000, optimizer: optax.GradientTransformation = DEFAULT_ADAM_OPTIMIZER, n_opt_step: int = 1000, max_opt_step: int | None = None, tol: float | None = None, use_value_and_grad_from_state: bool = False, clear_cache_after_opt: bool = False) -> tuple[dict[str, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]

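   Fit a model using random search followed by local optimization.

   ``n_sample`` parameter sets are drawn with ``init_sampler`` and their log
   likelihoods are evaluated in batches of ``batch_size``. The ``n_best``
   samples with the highest log likelihood are then refined with
   ``optimizer``, and the best optimized parameters are returned together
   with their log likelihood.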
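
   The strategy above can be sketched with plain JAX and Optax. This is a
   minimal, hypothetical illustration rather than eztaox's implementation: the
   Gaussian ``log_likelihood`` and ``toy_init_sampler`` are stand-ins for an
   EzTaoX model and a user-supplied ``init_sampler``, the local optimizer is
   ``optax.adam(1e-2)``, and ``n_sample`` is assumed to be a multiple of
   ``batch_size``.

   .. code-block:: python

      import jax
      import jax.numpy as jnp
      import optax

      # Hypothetical toy data; the maximum-likelihood "mu" is the sample mean (2.05).
      DATA = jnp.array([1.8, 2.1, 2.4, 1.9, 2.3, 2.0, 2.2, 1.7])


      def log_likelihood(params):
          """Gaussian log likelihood of DATA (stand-in for a light curve model)."""
          mu, log_sigma = params["mu"], params["log_sigma"]
          z = (DATA - mu) / jnp.exp(log_sigma)
          return jnp.sum(-0.5 * z**2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi))


      def toy_init_sampler(key, n):
          """Hypothetical sampler: n uniform draws for each parameter."""
          k1, k2 = jax.random.split(key)
          return {
              "mu": jax.random.uniform(k1, (n,), minval=-5.0, maxval=5.0),
              "log_sigma": jax.random.uniform(k2, (n,), minval=-3.0, maxval=2.0),
          }


      def random_search_sketch(key, n_sample=4000, n_best=4, batch_size=1000,
                               n_opt_step=1000):
          assert n_sample % batch_size == 0, "sketch assumes full batches"
          samples = toy_init_sampler(key, n_sample)

          # 1. Evaluate the log likelihood batch by batch to bound memory use.
          batches = jax.tree_util.tree_map(
              lambda a: a.reshape(-1, batch_size), samples)
          ll = jax.lax.map(jax.vmap(log_likelihood), batches).reshape(-1)

          # 2. Keep the n_best samples with the highest log likelihood.
          _, idx = jax.lax.top_k(ll, n_best)
          starts = jax.tree_util.tree_map(lambda a: a[idx], samples)

          # 3. Refine each retained sample with Adam on the negative log likelihood.
          opt = optax.adam(1e-2)
          loss_grad = jax.grad(lambda p: -log_likelihood(p))

          def optimize(p0):
              def step(carry, _):
                  p, state = carry
                  updates, state = opt.update(loss_grad(p), state, p)
                  return (optax.apply_updates(p, updates), state), None

              (p, _), _ = jax.lax.scan(step, (p0, opt.init(p0)), None,
                                       length=n_opt_step)
              return p

          optimized = jax.vmap(optimize)(starts)

          # 4. Return the optimized sample with the highest log likelihood.
          final_ll = jax.vmap(log_likelihood)(optimized)
          best = jnp.argmax(final_ll)
          return jax.tree_util.tree_map(lambda a: a[best], optimized), final_ll[best]


      best_params, best_ll = random_search_sketch(jax.random.PRNGKey(0))
      print(best_params["mu"])  # should be close to the sample mean, 2.05

   Evaluating the random samples in fixed-size batches bounds peak memory,
   while the local optimization of the ``n_best`` starting points is
   vectorized with ``jax.vmap``.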

   :param model: EzTaoX light curve model.
   :type model: UniVarModel | MultiVarModel
   :param init_sampler: Function to sample initial parameters.
   :type init_sampler: Callable
   :param prng_key: Random number generator key.
   :type prng_key: jax.random.PRNGKey
   :param n_sample: Number of random samples to draw.
   :type n_sample: int
   :param n_best: Number of best samples (ranked by log likelihood) to keep
                  for local optimization.
   :type n_best: int
   :param batch_size: Batch size used when evaluating the likelihood of the
                      randomly drawn samples. Defaults to 1000.
   :type batch_size: int, optional
   :param optimizer: Optax optimizer used for local optimization. Defaults to
                     ``DEFAULT_ADAM_OPTIMIZER``, i.e. ``optax.adam(1e-2)``.
   :type optimizer: optax.GradientTransformation, optional
   :param n_opt_step: Number of optimization steps per retained sample.
                      Defaults to 1000, which is suited to the default Adam
                      optimizer.
   :type n_opt_step: int, optional
   :param max_opt_step: Maximum number of optimization steps when
                        using the tolerance-based stopping criterion. Defaults to None.
   :type max_opt_step: int | None, optional
   :param tol: Gradient-norm tolerance for early stopping. This criterion is
               only used when ``max_opt_step`` is also provided. Defaults to
               None.
   :type tol: float | None, optional
   :param use_value_and_grad_from_state: Whether to reuse the objective value
                                         and gradient stored in the optimizer state, when available,
                                         instead of recomputing them. This is useful for Optax
                                         optimizers such as L-BFGS. Defaults to False.
   :type use_value_and_grad_from_state: bool, optional
   :param clear_cache_after_opt: Whether to clear JAX caches after local
                                 optimization. Defaults to False.
   :type clear_cache_after_opt: bool, optional

   :returns: Best parameters and their log likelihood.
   :rtype: tuple[dict[str, JAXArray], JAXArray]
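
   When ``optimizer`` is Optax's L-BFGS, ``use_value_and_grad_from_state=True``
   allows each step to reuse the objective value and gradient that the line
   search has already stored in the optimizer state. The loop below is a
   hypothetical sketch of that pattern combined with a gradient-norm stopping
   rule in the spirit of ``tol`` and ``max_opt_step``. It minimizes a toy
   quadratic rather than an EzTaoX likelihood, assumes an Optax release that
   provides ``optax.lbfgs``, and is not eztaox's implementation.

   .. code-block:: python

      import jax
      import jax.numpy as jnp
      import optax


      def loss(params):
          """Hypothetical toy loss with its minimum at x = 3, y = -1."""
          return (params["x"] - 3.0) ** 2 + 2.0 * (params["y"] + 1.0) ** 2


      def grad_norm(grads):
          """Global L2 norm of a pytree of gradients."""
          leaves = jax.tree_util.tree_leaves(grads)
          return jnp.sqrt(sum(jnp.sum(g**2) for g in leaves))


      def lbfgs_until_tol(params, tol=1e-4, max_opt_step=100):
          opt = optax.lbfgs()
          # Returns the value and gradient cached in the L-BFGS state when
          # available, and computes them otherwise (e.g. on the first step).
          value_and_grad = optax.value_and_grad_from_state(loss)
          state = opt.init(params)
          n_updates = 0
          for _ in range(max_opt_step):
              value, grads = value_and_grad(params, state=state)
              if grad_norm(grads) < tol:  # early stopping on the gradient norm
                  break
              updates, state = opt.update(
                  grads, state, params, value=value, grad=grads, value_fn=loss
              )
              params = optax.apply_updates(params, updates)
              n_updates += 1
          return params, n_updates


      init = {"x": jnp.array(0.0), "y": jnp.array(0.0)}
      final, n_used = lbfgs_until_tol(init)

   A plain Python loop keeps the sketch readable; a compiled version would
   typically move the loop into ``jax.lax.while_loop``.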


.. py:function:: simple_optimizer(model: eztaox.models.UniVarModel | eztaox.models.MultiVarModel, init_sample: dict[str, tinygp.helpers.JAXArray], *, optimizer: optax.GradientTransformation = DEFAULT_ADAM_OPTIMIZER, n_step: int = 3000, use_value_and_grad_from_state: bool = False) -> tuple[dict[str, tinygp.helpers.JAXArray], tuple[dict[str, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, dict[str, tinygp.helpers.JAXArray]]]

   Fit a model by running a single optimizer for ``n_step`` steps, starting
   from one initial guess.

   :param model: EzTaoX light curve model.
   :type model: UniVarModel | MultiVarModel
   :param init_sample: Initial guess of the model parameters.
   :type init_sample: dict[str, JAXArray]
   :param optimizer: Optax optimizer to use. Defaults to
                     ``DEFAULT_ADAM_OPTIMIZER``, i.e. ``optax.adam(1e-2)``.
   :type optimizer: optax.GradientTransformation, optional
   :param n_step: Number of optimization steps. Defaults to 3000.
   :type n_step: int, optional
   :param use_value_and_grad_from_state: Whether to reuse the objective value
                                         and gradient stored in the optimizer state, when available,
                                         instead of recomputing them. This is useful for Optax
                                         optimizers such as L-BFGS. Defaults to False.
   :type use_value_and_grad_from_state: bool, optional

   :returns: Best parameters, and a tuple of (parameter history, loss history,
             gradient history).
   :rtype: tuple[dict[str, JAXArray], tuple[dict[str, JAXArray], JAXArray, dict[str, JAXArray]]]
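
   The returned histories can be illustrated with a short ``jax.lax.scan``
   loop that records the parameters, loss, and gradient at every step. This
   is a hypothetical sketch on a toy quadratic loss, not eztaox's
   implementation; in particular, taking the best parameters to be the
   lowest-loss entry of the history is an assumption.

   .. code-block:: python

      import jax
      import jax.numpy as jnp
      import optax


      def loss(params):
          """Hypothetical toy loss with its minimum at a = 1, b = -2."""
          return (params["a"] - 1.0) ** 2 + (params["b"] + 2.0) ** 2


      def simple_optimizer_sketch(init_sample, optimizer=optax.adam(1e-2),
                                  n_step=3000):
          value_and_grad = jax.value_and_grad(loss)

          def step(carry, _):
              params, state = carry
              value, grads = value_and_grad(params)
              updates, state = optimizer.update(grads, state, params)
              new_params = optax.apply_updates(params, updates)
              # Emit the pre-update parameters with their loss and gradient.
              return (new_params, state), (params, value, grads)

          init_carry = (init_sample, optimizer.init(init_sample))
          _, (param_hist, loss_hist, grad_hist) = jax.lax.scan(
              step, init_carry, None, length=n_step
          )
          # Assumption: "best" means the lowest loss seen during the run.
          best = jnp.argmin(loss_hist)
          best_params = jax.tree_util.tree_map(lambda h: h[best], param_hist)
          return best_params, (param_hist, loss_hist, grad_hist)


      init = {"a": jnp.array(0.0), "b": jnp.array(0.0)}
      best, (p_hist, l_hist, g_hist) = simple_optimizer_sketch(init)
      print(l_hist.shape)  # (3000,)

   Because the whole loop runs inside one ``jax.lax.scan``, each history
   holds ``n_step`` entries stacked along the leading axis.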


