[1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Global matplotlib styling applied to every figure in this notebook.
mpl.rcParams.update(
    {
        # render text with matplotlib's built-in engine (no external LaTeX)
        "text.usetex": False,
        # font sizes for axis labels, figure-level labels and tick labels
        "axes.labelsize": 20,
        "figure.labelsize": 18,
        "xtick.labelsize": 16,
        "ytick.labelsize": 16,
        # zero out all constrained-layout spacing/padding settings
        "figure.constrained_layout.wspace": 0,
        "figure.constrained_layout.hspace": 0,
        "figure.constrained_layout.h_pad": 0,
        "figure.constrained_layout.w_pad": 0,
        # thickness of the axes frame
        "axes.linewidth": 1.2,
    }
)

import jax
import jax.numpy as jnp

# Enable 64-bit floats in JAX (it uses 32-bit by default). Set this before
# creating any arrays; the float64 outputs below depend on it.
jax.config.update("jax_enable_x64", True)

Damped Harmonic Oscillator (DHO)#

This notebook demonstrates how to fit the Damped Harmonic Oscillator (DHO) model to a single-band light curve. For multiband fitting, see notebook 02_Multiband.

The DHO model is a second-order continuous-time autoregressive moving average (CARMA) process. It is defined as the solution to the stochastic differential equation

\[\mathrm{d}^{2}x + \alpha_{1}\mathrm{d}x + \alpha_{0}x = \beta_{0}\mathrm{d}W + \beta_{1}\mathrm{d}\bigl(\mathrm{d}W\bigr),\]

where \(W\) denotes a Wiener process. The coefficients \(\alpha_{0}\) and \(\alpha_{1}\) are the autoregressive parameters, while \(\beta_{0}\) and \(\beta_{1}\) are the moving-average parameters.

Note

The CARMA parameter notation follows the convention of Kelly+14.

1. Light Curve Simulation#

[2]:
from eztaox.kernels.quasisep import CARMA
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise
[3]:
# CARMA(2,1)/DHO parameters (linear scale)
alphas = jnp.asarray([0.0002, 0.05])
betas = jnp.asarray([0.0006, 0.03])
# parameters handed to the simulator: natural log of [alphas, betas]
sim_params = {"log_kernel_param": jnp.log(jnp.hstack([alphas, betas]))}

# simulation configurations
lc_seed = 2
random_seed = 3
noise_seed = 12
min_dt, max_dt = 1, 3650.0

# simulate
# CARMA.init receives linear-scale coefficients, consistent with every other
# CARMA.init call in this notebook; the log-scale values go in via sim_params.
s = UniVarSim(CARMA.init(alphas, betas), min_dt, max_dt, sim_params)
# draw a 200-point light curve; the two PRNG keys are built from lc_seed and
# random_seed (see UniVarSim.random for what each key controls)
sim_t, sim_y = s.random(
    200, jax.random.PRNGKey(lc_seed), jax.random.PRNGKey(random_seed)
)
# constant measurement uncertainty of 0.04 for every epoch
sim_yerr = jnp.ones_like(sim_t) * 0.04
sim_y_noisy = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(noise_seed))

plt.errorbar(sim_t, sim_y_noisy, sim_yerr, fmt=".")
plt.xlabel("Time (day)")
plt.ylabel("Flux (mag)")
[3]:
Text(0, 0.5, 'Flux (mag)')
../_images/notebooks_03_DHO_4_1.png

2. Fitting#

Here, we demonstrate how to use the UniVarModel for fitting single-band light curves.

[4]:
import numpyro
import numpyro.distributions as dist
from eztaox.fitter import random_search
from eztaox.models import UniVarModel
from numpyro.handlers import seed as numpyro_seed

2.1 Initialize Light Curve Model#

[5]:
# Build the single-band light-curve model that is fitted in the next sections.
zero_mean = False
p = 2  # CARMA p-order
# Placeholder kernel parameters, stored as natural logs. The first `p` entries
# are passed to CARMA.init as its first argument, the rest as its second.
# NOTE(review): presumably these only initialise the kernel and are replaced
# by the sampled/fitted parameters later -- confirm against UniVarModel.
test_params = {"log_kernel_param": jnp.log(np.array([0.1, 1.1, 1.0, 3.0]))}

# define kernel
# exponentiate back to linear scale before handing the values to CARMA.init
k = CARMA.init(
    jnp.exp(test_params["log_kernel_param"][:p]),
    jnp.exp(test_params["log_kernel_param"][p:]),
)

# define univar model
m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)
# bare expression: display the model's repr as the cell output
m
[5]:
UniVarModel(
  X=(f64[200], i64[200]),
  y=f64[200],
  diag=f64[200],
  base_kernel_def=<jax._src.util.HashablePartial object at 0x74545884f610>,
  multiband_kernel=<class 'eztaox.kernels.quasisep.MultibandLowRank'>,
  t_in_bands=[f64[200]],
  concat_inds_in_bands=i64[200],
  n_band=1,
  mean_func=None,
  amp_scale_func=None,
  lag_func=None,
  zero_mean=False,
  has_jitter=False,
  has_lag=False
)

2.2 Define Init Sampler#

[6]:
def init_sampler():
    """Draw one set of DHO parameters and a mean value from the priors.

    NumPyro sites registered, in order:
      * "log_alpha": two values, each Uniform(-16, 0)
      * "log_beta": two values, each Uniform(-10, 2)
      * "log_kernel_param": deterministic, log_alpha followed by log_beta
      * "mean": Uniform(-0.2, 0.2)

    Returns:
        dict with the keys "log_kernel_param" and "mean".
    """
    # priors (constructing them does not consume any randomness)
    alpha_prior = dist.Uniform(low=-16.0, high=0.0).expand([2])
    beta_prior = dist.Uniform(low=-10.0, high=2.0).expand([2])
    mean_prior = dist.Uniform(low=-0.2, high=0.2)

    # DHO alpha & beta parameters, in natural log
    log_alpha = numpyro.sample("log_alpha", alpha_prior)
    log_beta = numpyro.sample("log_beta", beta_prior)

    # single vector ordered as [alphas, betas]
    stacked = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", stacked)

    # mean
    mean = numpyro.sample("mean", mean_prior)

    return {"log_kernel_param": stacked, "mean": mean}
[7]:
# generate a random initial guess
sample_key = jax.random.PRNGKey(1)
# wrap init_sampler with numpyro's seed handler so its sample sites can draw
# random values from sample_key, then call it once
prior_sample = numpyro_seed(init_sampler, rng_seed=sample_key)()
# bare expression: display the drawn parameter dict
prior_sample
[7]:
{'log_kernel_param': Array([ -4.87392318, -13.69996053,  -3.87823238,  -6.84304299], dtype=float64),
 'mean': Array(-0.06222013, dtype=float64)}

2.3 MLE Fitting#

[8]:
%%time
model = m
sampler = init_sampler
fit_key = jax.random.PRNGKey(1)
n_sample = 10_000
n_best = 10  # it seems like this number needs to be high

bestP, ll = random_search(model, init_sampler, fit_key, n_sample, n_best)
bestP
CPU times: user 4.21 s, sys: 154 ms, total: 4.36 s
Wall time: 4.25 s
[8]:
{'log_kernel_param': Array([-8.09148664, -2.8628066 , -7.06743361, -3.40787672], dtype=float64),
 'mean': Array(0.04581574, dtype=float64)}
[9]:
# Compare the true DHO parameters with the MLE estimate, both in natural log.
# Note that EzTao follows the CARMA notation from Moreno+19,
# and EzTaoX adopts the CARMA notation from Kelly+14.
# The main difference is that the alpha parameter index is reversed.
print("True DHO Params (in natural log):")
print(np.log(np.hstack([alphas, betas])))
print("MLE DHO Params (in natural log):")
print(bestP["log_kernel_param"])
True DHO Params (in natual log):
[-8.51719319 -2.99573227 -7.4185809  -3.5065579 ]
MLE DHO Params (in natual log):
[-8.09148664 -2.8628066  -7.06743361 -3.40787672]

3. MCMC#

[10]:
import arviz as az
from numpyro.infer import MCMC, NUTS
[11]:
def numpyro_model(m):
    """NumPyro model used for MCMC: sample the priors, then call `m.sample`.

    NumPyro sites registered, in order:
      * "log_alpha": two values, each Uniform(-16, 0)
      * "log_beta": two values, each Uniform(-10, 2)
      * "log_kernel_param": deterministic, log_alpha followed by log_beta
      * "mean": Normal(0, 0.1)

    Args:
        m: light-curve model object; its `sample` method is called with the
            sampled parameter dict.
    """
    # priors (constructing them does not consume any randomness)
    alpha_prior = dist.Uniform(low=-16.0, high=0.0).expand([2])
    beta_prior = dist.Uniform(low=-10.0, high=2.0).expand([2])
    # mean: use a normal prior for better convergence
    mean_prior = dist.Normal(0.0, 0.1)

    log_alpha = numpyro.sample("log_alpha", alpha_prior)
    log_beta = numpyro.sample("log_beta", beta_prior)

    # single vector ordered as [alphas, betas]
    stacked = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", stacked)

    mean = numpyro.sample("mean", mean_prior)

    m.sample({"log_kernel_param": stacked, "mean": mean})
[12]:
%%time

# Rebuild the kernel and light-curve model with the same statements used in
# Section 2.1. The priors for MCMC come from `numpyro_model`, which differs
# from `init_sampler` in using a Normal prior on the mean.
zero_mean = False
p = 2

k = CARMA.init(
    jnp.exp(test_params["log_kernel_param"][:p]),
    jnp.exp(test_params["log_kernel_param"][p:]),
)
m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)

# NUTS sampler configuration: dense mass matrix, 0.9 target acceptance
# probability, and the init_to_sample initialisation strategy.
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=numpyro.infer.init_to_sample,
)

# single chain: 2000 warmup iterations followed by 4000 kept samples
mcmc = MCMC(
    nuts_kernel,
    num_warmup=2000,
    num_samples=4000,
    num_chains=1,
    # progress_bar=False,
)

mcmc_seed = 0
# `m` is forwarded to numpyro_model as its argument
mcmc.run(jax.random.PRNGKey(mcmc_seed), m)
# convert to an ArviZ object for the plotting cells below
data = az.from_numpyro(mcmc)
mcmc.print_summary()
sample: 100%|██████████| 6000/6000 [00:35<00:00, 168.12it/s, 95 steps of size 2.39e-02. acc. prob=0.96]

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
log_alpha[0]     -8.88      2.54     -8.95    -12.52     -4.50    202.38      1.00
log_alpha[1]     -2.71      1.19     -3.07     -4.36     -0.54    190.30      1.00
 log_beta[0]     -7.11      1.73     -7.46     -9.42     -3.78    199.69      1.00
 log_beta[1]     -3.53      0.82     -3.43     -3.65     -3.02     96.77      1.01
        mean      0.03      0.06      0.03     -0.08      0.12    766.08      1.00

Number of divergences: 0
CPU times: user 40 s, sys: 380 ms, total: 40.4 s
Wall time: 40.3 s

Visualize Chains, Posterior Distributions#

[13]:
import warnings

# Silence every RuntimeWarning from this point on.
# NOTE(review): presumably meant to keep the plotting cells below free of
# numerical warnings -- confirm which call actually emits them.
warnings.filterwarnings("ignore", category=RuntimeWarning)
[14]:
# Trace plot of the sampled variables.
az.plot_trace(data, var_names=["log_alpha", "log_beta", "mean"])
# add vertical space between the subplot rows
plt.subplots_adjust(hspace=0.4)
../_images/notebooks_03_DHO_21_0.png
[15]:
# Pair plot of the posterior with the true (simulated) parameter values marked.
az.plot_pair(
    data,
    var_names=["log_alpha", "log_beta", "mean"],
    # Keys are the variable names; vector-valued variables take an array with
    # one reference value per element. The previous "log_alpha 0"-style keys
    # triggered "reference_values does not include reference value for:
    # log_beta, log_alpha", so no markers were drawn for those variables.
    reference_values={
        "log_alpha": np.asarray(np.log(alphas)),
        "log_beta": np.asarray(np.log(betas)),
        "mean": 0.0,
    },
    reference_values_kwargs={"color": "orange", "markersize": 20, "marker": "s"},
    kind="scatter",
    marginals=True,
    textsize=25,
)
# remove the gaps between panels
plt.subplots_adjust(hspace=0.0, wspace=0.0)
/home/docs/checkouts/readthedocs.org/user_builds/eztaox/envs/stable/lib/python3.11/site-packages/arviz/plots/backends/matplotlib/pairplot.py:85: UserWarning: Argument reference_values does not include reference value for: log_beta, log_alpha
  warnings.warn(
../_images/notebooks_03_DHO_22_1.png

4. Second-order Statistics#

[16]:
from eztaox.kernel_stat2 import gpStat2

# Log-spaced evaluation grids, 50 points each (np.logspace's default count).
ts = np.logspace(0, 4, num=50)  # time grid: 1 to 1e4
fs = np.logspace(-4, 0, num=50)  # frequency grid: 1e-4 to 1
[17]:
# get MCMC samples
# merge the "chain" and "draw" dimensions into a single "sample" dimension
flatPost = data.posterior.stack(sample=["chain", "draw"])
# transpose the raw values; the result is mapped over its leading axis below
# (presumably one row of four log parameters per draw -- TODO confirm shape)
log_carma_draws = flatPost["log_kernel_param"].values.T
[18]:
# create second-order stat object
# kernel built from the true (simulation) alphas and betas, in linear scale
dho_k = CARMA.init(alphas, betas)
gpStat2_dho = gpStat2(dho_k)

4.1 Structure Function#

[19]:
# compute sf for MCMC draws
# vmap over the draws (axis 0 of the second argument) while sharing the same
# time grid `ts`; draws are exponentiated back to linear scale first
mcmc_sf = jax.vmap(gpStat2_dho.sf, in_axes=(None, 0))(ts, jnp.exp(log_carma_draws))
[20]:
## plot
# true SF
plt.loglog(ts, gpStat2_dho.sf(ts), c="k", label="True SF", zorder=100, lw=2)
# legend is drawn here, before the unlabeled MCMC curves are added
plt.legend(fontsize=15)
# MCMC SFs (every 50th draw, to keep the figure readable)
for sf in mcmc_sf[::50]:
    plt.loglog(ts, sf, c="tab:green", alpha=0.15)

plt.xlabel("Time")
plt.ylabel("SF")
[20]:
Text(0, 0.5, 'SF')
../_images/notebooks_03_DHO_29_1.png

4.2 Power Spectral Density (PSD)#

[21]:
# compute psd for MCMC draws
# vmap over the draws (axis 0 of the second argument) while sharing the same
# frequency grid `fs`; draws are exponentiated back to linear scale first
mcmc_psd = jax.vmap(gpStat2_dho.psd, in_axes=(None, 0))(fs, jnp.exp(log_carma_draws))
[22]:
## plot
# true PSD
plt.loglog(fs, gpStat2_dho.psd(fs), c="k", label="True PSD", zorder=100, lw=2)
# legend is drawn here, before the unlabeled MCMC curves are added
plt.legend(fontsize=15)

# MCMC PSDs (every 50th draw, to keep the figure readable)
for psd in mcmc_psd[::50]:
    plt.loglog(fs, psd, c="tab:green", alpha=0.15)

plt.xlabel("Frequency")
plt.ylabel("PSD")
[22]:
Text(0, 0.5, 'PSD')
../_images/notebooks_03_DHO_32_1.png