eztaox.fitter
This module contains functions that fit a model to data.
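Attributes
Functions
- random_search(model, init_sampler, prng_key, n_sample, n_best, *, ...) – Fit a model using random search plus local optimization.
- simple_optimizer(model, init_sample, *, ...) – Fit a model using a simple optimizer.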
Module Contents
- random_search(model: eztaox.models.UniVarModel | eztaox.models.MultiVarModel, init_sampler: collections.abc.Callable, prng_key: jax.random.PRNGKey, n_sample: int, n_best: int, *, batch_size: int = 1000, optimizer: optax.GradientTransformation = DEFAULT_ADAM_OPTIMIZER, n_opt_step: int = 1000, max_opt_step: int | None = None, tol: float | None = None, use_value_and_grad_from_state: bool = False, clear_cache_after_opt: bool = False) → tuple[dict[str, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray]
Fit a model using random search plus local optimization.
- Parameters:
model (UniVarModel | MultiVarModel) – EzTaoX light curve model.
init_sampler (Callable) – Function to sample initial parameters.
prng_key (jax.random.PRNGKey) – Random number generator key.
n_sample (int) – Number of random samples to draw.
n_best (int) – Number of best samples (selected based on their likelihood values) to keep for optimization.
batch_size (int, optional) – The batch size used when evaluating the likelihood of the randomly drawn samples. Defaults to 1000.
optimizer (optax.GradientTransformation, optional) – Optimizer used in local optimization. Defaults to optax.adam(1e-2).
n_opt_step (int, optional) – Number of optimization steps per retained sample. Defaults to 1000 for the default Adam optimizer.
max_opt_step (int | None, optional) – Maximum number of optimization steps when using the tolerance-based stopping criterion. Defaults to None.
tol (float | None, optional) – Gradient-norm tolerance for early stopping. This criterion is only used when max_opt_step is also provided. Defaults to None.
use_value_and_grad_from_state (bool, optional) – Whether to reuse the value and gradient stored in the optimizer state when available. This is useful for Optax optimizers such as L-BFGS. Defaults to False.
clear_cache_after_opt (bool, optional) – Whether to clear JAX caches after optimization. Defaults to False.
- Returns:
Best parameters and their log likelihood.
- Return type:
tuple[dict[str, JAXArray], JAXArray]
- simple_optimizer(model: eztaox.models.UniVarModel | eztaox.models.MultiVarModel, init_sample: dict[str, tinygp.helpers.JAXArray], *, optimizer: optax.GradientTransformation = DEFAULT_ADAM_OPTIMIZER, n_step: int = 3000, use_value_and_grad_from_state: bool = False) → tuple[dict[str, tinygp.helpers.JAXArray], tuple[dict[str, tinygp.helpers.JAXArray], tinygp.helpers.JAXArray, dict[str, tinygp.helpers.JAXArray]]]
Fit a model using a simple optimizer.
- Parameters:
model (UniVarModel | MultiVarModel) – EzTaoX light curve model.
init_sample (dict[str, JAXArray]) – The initial guess of parameters.
optimizer (optax.GradientTransformation, optional) – Optimizer to use. Defaults to optax.adam(1e-2).
n_step (int, optional) – Number of optimization steps. Defaults to 3000.
use_value_and_grad_from_state (bool, optional) – Whether to reuse the value and gradient stored in the optimizer state when available. This is useful for Optax optimizers such as L-BFGS. Defaults to False.
- Returns:
Best parameters and a tuple of (parameter history, loss history, gradient history).
- Return type:
tuple[dict[str, JAXArray], tuple[dict[str, JAXArray], JAXArray, dict[str, JAXArray]]]