[1]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
# Notebook-wide matplotlib styling.
mpl.rcParams.update(
    {
        # do not route text rendering through TeX
        "text.usetex": False,
        # slightly heavier axes frame
        "axes.linewidth": 1.2,
        # label font sizes for axes, figure-level labels and tick labels
        **{
            f"{target}.labelsize": size
            for target, size in (
                ("axes", 20),
                ("figure", 18),
                ("xtick", 16),
                ("ytick", 16),
            )
        },
        # zero out every constrained-layout spacing/padding setting
        **{
            f"figure.constrained_layout.{setting}": 0
            for setting in ("wspace", "hspace", "h_pad", "w_pad")
        },
    }
)
import jax
import jax.numpy as jnp
# You should always set this: it switches JAX to 64-bit mode, which is why
# the arrays printed throughout this notebook are float64.
jax.config.update("jax_enable_x64", True)
Damped Harmonic Oscillator (DHO)#
This notebook demonstrates how to fit the Damped Harmonic Oscillator (DHO) model to a single-band light curve. For multiband fitting, see notebook 02_Multiband.
The DHO model is a second-order continuous-time autoregressive moving average (CARMA) process. It is defined as the solution to the stochastic differential equation
\[\mathrm{d}^{2}x + \alpha_{1}\mathrm{d}x + \alpha_{0}x
= \beta_{0}\mathrm{d}W + \beta_{1}\mathrm{d}\bigl(\mathrm{d}W\bigr),\]
where \(W\) denotes a Wiener process. The coefficients \(\alpha_{0}\) and \(\alpha_{1}\) are the autoregressive parameters, while \(\beta_{0}\) and \(\beta_{1}\) are the moving-average parameters.
Note
The CARMA parameter notation follows the convention of Kelly+14.
1. Light Curve Simulation#
[2]:
from eztaox.kernels.quasisep import CARMA
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise
[3]:
# CARMA(2,1)/DHO parameters (linear scale)
alphas = jnp.asarray([0.0002, 0.05])
betas = jnp.asarray([0.0006, 0.03])
# the simulator receives the parameters as natural logs, alphas first
sim_params = {"log_kernel_param": jnp.log(jnp.hstack([alphas, betas]))}
# simulation configurations
lc_seed = 2
random_seed = 3
noise_seed = 12
min_dt, max_dt = 1, 3650.0
# simulate
# CARMA.init is given the linear-scale coefficients, consistent with every
# other CARMA.init call in this notebook (sections 2.1, 3 and 4); the
# log-scale values are supplied separately through `sim_params`.
s = UniVarSim(CARMA.init(alphas, betas), min_dt, max_dt, sim_params)
# draw 200 epochs (the model repr in section 2.1 reports arrays of length 200)
sim_t, sim_y = s.random(
    200, jax.random.PRNGKey(lc_seed), jax.random.PRNGKey(random_seed)
)
# constant uncertainty of 0.04 on every epoch, then add noise with its own key
sim_yerr = jnp.ones_like(sim_t) * 0.04
sim_y_noisy = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(noise_seed))
# show the noisy light curve with its error bars
plt.errorbar(sim_t, sim_y_noisy, sim_yerr, fmt=".")
plt.xlabel("Time (day)")
plt.ylabel("Flux (mag)")
[3]:
Text(0, 0.5, 'Flux (mag)')
2. Fitting#
Here, we demonstrate how to use the UniVarModel for fitting single-band light curves.
[4]:
import numpyro
import numpyro.distributions as dist
from eztaox.fitter import random_search
from eztaox.models import UniVarModel
from numpyro.handlers import seed as numpyro_seed
2.1 Initialize Light Curve Model#
[5]:
# model configuration
zero_mean = False
p = 2  # CARMA p-order
# starting kernel parameters, stored as natural logs (first p entries: alpha)
test_params = {"log_kernel_param": jnp.log(np.array([0.1, 1.1, 1.0, 3.0]))}
# define kernel: split the log vector at p and convert back to linear scale
log_init = test_params["log_kernel_param"]
k = CARMA.init(jnp.exp(log_init[:p]), jnp.exp(log_init[p:]))
# define univar model on the simulated, noisy light curve
m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)
m
[5]:
UniVarModel(
X=(f64[200], i64[200]),
y=f64[200],
diag=f64[200],
base_kernel_def=<jax._src.util.HashablePartial object at 0x74545884f610>,
multiband_kernel=<class 'eztaox.kernels.quasisep.MultibandLowRank'>,
t_in_bands=[f64[200]],
concat_inds_in_bands=i64[200],
n_band=1,
mean_func=None,
amp_scale_func=None,
lag_func=None,
zero_mean=False,
has_jitter=False,
has_lag=False
)
2.2 Define Init Sampler#
[6]:
def init_sampler():
    """Draw one set of model parameters from the prior.

    Samples two ``log_alpha`` values from Uniform(-16, 0), two ``log_beta``
    values from Uniform(-10, 2) and a scalar ``mean`` from Uniform(-0.2, 0.2).
    The alpha and beta draws are concatenated (alpha first) and also recorded
    as the deterministic site ``log_kernel_param``.

    Returns:
        dict with keys ``"log_kernel_param"`` and ``"mean"``.
    """
    # DHO Alpha & Beta parameters
    alpha_prior = dist.Uniform(low=-16.0, high=0.0).expand([2])
    beta_prior = dist.Uniform(low=-10.0, high=2.0).expand([2])
    log_alpha = numpyro.sample("log_alpha", alpha_prior)
    log_beta = numpyro.sample("log_beta", beta_prior)
    log_kernel_param = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", log_kernel_param)
    # mean
    mean_prior = dist.Uniform(low=-0.2, high=0.2)
    mean = numpyro.sample("mean", mean_prior)
    return {"log_kernel_param": log_kernel_param, "mean": mean}
[7]:
# generate a random initial guess
sample_key = jax.random.PRNGKey(1)
# run the sampler under the seed handler so its sample sites get random keys
with numpyro_seed(rng_seed=sample_key):
    prior_sample = init_sampler()
prior_sample
[7]:
{'log_kernel_param': Array([ -4.87392318, -13.69996053, -3.87823238, -6.84304299], dtype=float64),
'mean': Array(-0.06222013, dtype=float64)}
2.3 MLE Fitting#
[8]:
%%time
model = m
sampler = init_sampler
fit_key = jax.random.PRNGKey(1)
n_sample = 10_000
n_best = 10 # it seems like this number needs to be high
bestP, ll = random_search(model, init_sampler, fit_key, n_sample, n_best)
bestP
CPU times: user 4.21 s, sys: 154 ms, total: 4.36 s
Wall time: 4.25 s
[8]:
{'log_kernel_param': Array([-8.09148664, -2.8628066 , -7.06743361, -3.40787672], dtype=float64),
'mean': Array(0.04581574, dtype=float64)}
[9]:
# Compare the true DHO parameters with the MLE estimate.
# Note that EzTao follows the CARMA notation from Moreno+19,
# and EzTaoX adopts the CARMA notation from Kelly+14.
# The main difference is that the alpha parameter index is reversed.
print("True DHO Params (in natural log):")
print(np.log(np.hstack([alphas, betas])))
print("MLE DHO Params (in natural log):")
print(bestP["log_kernel_param"])
True DHO Params (in natual log):
[-8.51719319 -2.99573227 -7.4185809 -3.5065579 ]
MLE DHO Params (in natual log):
[-8.09148664 -2.8628066 -7.06743361 -3.40787672]
3. MCMC#
[10]:
import arviz as az
from numpyro.infer import MCMC, NUTS
[11]:
def numpyro_model(m):
    """NumPyro model used for MCMC sampling.

    Places Uniform(-16, 0) priors on the two ``log_alpha`` values,
    Uniform(-10, 2) priors on the two ``log_beta`` values and a
    Normal(0, 0.1) prior on ``mean``, then passes the sampled parameters
    to ``m.sample``.

    Args:
        m: the light-curve model; only its ``sample`` method is used here.
    """
    alpha_prior = dist.Uniform(low=-16.0, high=0.0).expand([2])
    beta_prior = dist.Uniform(low=-10.0, high=2.0).expand([2])
    log_alpha = numpyro.sample("log_alpha", alpha_prior)
    log_beta = numpyro.sample("log_beta", beta_prior)
    # alpha first, then beta; also recorded as a deterministic site
    log_kernel_param = jnp.hstack([log_alpha, log_beta])
    numpyro.deterministic("log_kernel_param", log_kernel_param)
    # mean: use a normal prior for better convergence
    mean = numpyro.sample("mean", dist.Normal(0.0, 0.1))
    m.sample({"log_kernel_param": log_kernel_param, "mean": mean})
[12]:
%%time
# the following is different from the init_sampler
zero_mean = False
p = 2
k = CARMA.init(
jnp.exp(test_params["log_kernel_param"][:p]),
jnp.exp(test_params["log_kernel_param"][p:]),
)
m = UniVarModel(sim_t, sim_y_noisy, sim_yerr, k, zero_mean=zero_mean)
nuts_kernel = NUTS(
numpyro_model,
dense_mass=True,
target_accept_prob=0.9,
init_strategy=numpyro.infer.init_to_sample,
)
mcmc = MCMC(
nuts_kernel,
num_warmup=2000,
num_samples=4000,
num_chains=1,
# progress_bar=False,
)
mcmc_seed = 0
mcmc.run(jax.random.PRNGKey(mcmc_seed), m)
data = az.from_numpyro(mcmc)
mcmc.print_summary()
sample: 100%|██████████| 6000/6000 [00:35<00:00, 168.12it/s, 95 steps of size 2.39e-02. acc. prob=0.96]
mean std median 5.0% 95.0% n_eff r_hat
log_alpha[0] -8.88 2.54 -8.95 -12.52 -4.50 202.38 1.00
log_alpha[1] -2.71 1.19 -3.07 -4.36 -0.54 190.30 1.00
log_beta[0] -7.11 1.73 -7.46 -9.42 -3.78 199.69 1.00
log_beta[1] -3.53 0.82 -3.43 -3.65 -3.02 96.77 1.01
mean 0.03 0.06 0.03 -0.08 0.12 766.08 1.00
Number of divergences: 0
CPU times: user 40 s, sys: 380 ms, total: 40.4 s
Wall time: 40.3 s
Visualize Chains, Posterior Distributions#
[13]:
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
[14]:
# trace plots for the three sampled sites
az.plot_trace(data, var_names=["log_alpha", "log_beta", "mean"])
# add vertical spacing between the subplot rows
plt.subplots_adjust(hspace=0.4)
[15]:
# Pair plot of the posterior with the true (simulation) values overlaid.
# NOTE(review): the recorded output shows a UserWarning that no reference
# value was matched for log_alpha/log_beta -- confirm the key format that the
# installed ArviZ version expects.
az.plot_pair(
    data,
    var_names=["log_alpha", "log_beta", "mean"],
    kind="scatter",
    marginals=True,
    textsize=25,
    reference_values={
        # one entry per vector component, keyed "<name> <index>"
        **{f"log_alpha {i}": value for i, value in enumerate(np.log(alphas))},
        **{f"log_beta {i}": value for i, value in enumerate(np.log(betas))},
        "mean": 0.0,
    },
    reference_values_kwargs={"color": "orange", "markersize": 20, "marker": "s"},
)
# remove all spacing between the panels
plt.subplots_adjust(hspace=0.0, wspace=0.0)
/home/docs/checkouts/readthedocs.org/user_builds/eztaox/envs/stable/lib/python3.11/site-packages/arviz/plots/backends/matplotlib/pairplot.py:85: UserWarning: Argument reference_values does not include reference value for: log_beta, log_alpha
warnings.warn(
4. Second-order Statistics#
[16]:
from eztaox.kernel_stat2 import gpStat2
# log-spaced evaluation grids (50 points each, numpy's default):
# `ts` spans 1 to 1e4 and is used for the structure function,
# `fs` spans 1e-4 to 1 and is used for the PSD
ts = np.logspace(0, 4, num=50)
fs = np.logspace(-4, 0, num=50)
[17]:
# get MCMC samples
# merge the chain and draw dimensions into a single "sample" dimension
flatPost = data.posterior.stack(sample=["chain", "draw"])
# transpose the extracted array; the vmap calls in section 4 map over its
# leading axis (presumably one row per posterior draw -- TODO confirm)
log_carma_draws = flatPost["log_kernel_param"].values.T
[18]:
# create second-order stat object
# kernel built from the true (simulation) alphas and betas
dho_k = CARMA.init(alphas, betas)
gpStat2_dho = gpStat2(dho_k)
4.1 Structure Function#
[19]:
# compute sf for MCMC draws
# vectorise over axis 0 of the parameter array; the lag grid `ts` is shared
sf_per_draw = jax.vmap(gpStat2_dho.sf, in_axes=(None, 0))
mcmc_sf = sf_per_draw(ts, jnp.exp(log_carma_draws))
[20]:
## plot
# true SF, drawn on top (high zorder) and the only labelled curve
plt.loglog(ts, gpStat2_dho.sf(ts), c="k", label="True SF", zorder=100, lw=2)
plt.legend(fontsize=15)
# MCMC SFs: every 50th draw, one curve per column of the transposed array
plt.loglog(ts, mcmc_sf[::50].T, c="tab:green", alpha=0.15)
plt.xlabel("Time")
plt.ylabel("SF")
[20]:
Text(0, 0.5, 'SF')
4.2 Power Spectral Density (PSD)#
[21]:
# compute psd for MCMC draws
# vectorise over axis 0 of the parameter array; the frequency grid `fs` is shared
psd_per_draw = jax.vmap(gpStat2_dho.psd, in_axes=(None, 0))
mcmc_psd = psd_per_draw(fs, jnp.exp(log_carma_draws))
[22]:
## plot
# true PSD, drawn on top (high zorder) and the only labelled curve
plt.loglog(fs, gpStat2_dho.psd(fs), c="k", label="True PSD", zorder=100, lw=2)
plt.legend(fontsize=15)
# MCMC PSDs: every 50th draw, one curve per column of the transposed array
plt.loglog(fs, mcmc_psd[::50].T, c="tab:green", alpha=0.15)
plt.xlabel("Frequency")
plt.ylabel("PSD")
[22]:
Text(0, 0.5, 'PSD')