eztaox.fitter

This module contains fitter functions that fit a model to data.

Attributes

DEFAULT_ADAM_OPTIMIZER

Functions

random_search(→ tuple[dict[str, ...)

Fit a model using random search plus local optimization.

simple_optimizer(→ tuple[dict[str, ...)

Fit a model using a simple optimizer.

Module Contents

DEFAULT_ADAM_OPTIMIZER

random_search(→ tuple[dict[str, ...)

Fit a model using random search plus local optimization.

Parameters:
  • model (UniVarModel | MultiVarModel) – EzTaoX light curve model.

  • init_sampler (Callable) – Function to sample initial parameters.

  • prng_key (jax.random.PRNGKey) – Random number generator key.

  • n_sample (int) – Number of random samples to draw.

  • n_best (int) – Number of best samples (ranked by their likelihood values) to keep for local optimization.

  • batch_size (int, optional) – Batch size used when evaluating the likelihood of the randomly drawn samples. Defaults to 1000.

  • optimizer (optax.GradientTransformation, optional) – Optimizer used in local optimization. Defaults to optax.adam(1e-2).

  • n_opt_step (int, optional) – Number of optimization steps per retained sample. Defaults to 1000 for the default Adam optimizer.

  • max_opt_step (int | None, optional) – Maximum number of optimization steps when using the tolerance-based stopping criterion. Defaults to None.

  • tol (float | None, optional) – Gradient-norm tolerance for early stopping. This criterion is only used when max_opt_step is also provided. Defaults to None.

  • use_value_and_grad_from_state (bool, optional) – Whether to reuse value and gradients from the optimizer state when available. This is useful for Optax optimizers such as L-BFGS. Defaults to False.

  • clear_cache_after_opt (bool, optional) – Whether to clear JAX caches after the local optimization. Defaults to False.

Returns:

Best parameters and their log likelihood.

Return type:

tuple[dict[str, JAXArray], JAXArray]

simple_optimizer(model: eztaox.models.UniVarModel | eztaox.models.MultiVarModel, init_sample: dict[str, tinygp.helpers.JAXArray], *, optimizer: optax.GradientTransformation = DEFAULT_ADAM_OPTIMIZER, n_step: int = 3000, use_value_and_grad_from_state: bool = False) → tuple[dict[str, tinygp.helpers.JAXArray], tuple[dict[str, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, dict[str, tinygp.helpers.JAXArray]]]

Fit a model using a simple optimizer.

Parameters:
  • model (UniVarModel | MultiVarModel) – EzTaoX light curve model.

  • init_sample (dict[str, JAXArray]) – The initial guess of parameters.

  • optimizer (optax.GradientTransformation, optional) – Optimizer to use. Defaults to DEFAULT_ADAM_OPTIMIZER.

  • n_step (int, optional) – Number of optimization steps. Defaults to 3000.

  • use_value_and_grad_from_state (bool, optional) – Whether to reuse value and gradients from the optimizer state when available. This is useful for Optax optimizers such as L-BFGS. Defaults to False.

Returns:

Best parameters, and a tuple of (parameter history, loss history, gradient history).

Return type:

tuple[dict[str, JAXArray], tuple[dict[str, JAXArray], JAXArray, dict[str, JAXArray]]]