[1]:
import warnings

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np

# Suppress RuntimeWarnings globally.
# NOTE: this also hides genuine numerical problems, so comment it out when debugging.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Notebook-wide matplotlib style: larger fonts, thicker axes, and zero
# constrained-layout padding/spacing between panels.
_notebook_rc = {
    "text.usetex": False,
    "axes.labelsize": 20,
    "figure.labelsize": 18,
    "xtick.labelsize": 16,
    "ytick.labelsize": 16,
    "figure.constrained_layout.wspace": 0,
    "figure.constrained_layout.hspace": 0,
    "figure.constrained_layout.h_pad": 0,
    "figure.constrained_layout.w_pad": 0,
    "axes.linewidth": 1.2,
}
mpl.rcParams.update(_notebook_rc)

import jax
import jax.numpy as jnp

# JAX defaults to 32-bit floats; GP fitting should run in 64-bit precision.
# Set this right after importing JAX, before any JAX array is created.
jax.config.update("jax_enable_x64", True)
Multiband GP Fitting#
This notebook presents a complete workflow for multiband light-curve fitting with EzTaoX. A damped random walk (DRW) GP kernel is assumed.
[2]:
from eztaox.kernels.quasisep import Exp
from eztaox.simulator import UniVarSim
from eztaox.ts_utils import add_noise
1. Light Curve Simulation#
We first simulate DRW light curves in two different bands. The intrinsic DRW timescales are set to be the same across bands, and the amplitudes are set to differ. We also add a five-day time delay between these two bands.
[3]:
# True DRW parameters for each band: timescale (day) and amplitude (mag).
# Both bands share the same timescale; only the amplitudes differ.
drw_scale = {"g": 100, "r": 100}
drw_sigma = {"g": 0.25, "r": 0.15}
lc_seed = 2  # same light-curve seed for both bands (presumably a shared underlying realization)
sampling_seed = {"g": 2, "r": 5}  # seed for random sampling
noise_seed = {"g": 1, "r": 2}  # seed for mocking observational noise
min_dt, max_dt = 1, 3650.0
n_obs = 200  # number of observations per band
yerr_level = 0.05  # constant photometric uncertainty (mag)
true_lag = 5.0  # injected r-band delay relative to the g band (day)

ts, ys, yerrs, ys_noisy = {}, {}, {}, {}
for band in "gr":
    sim_params = {
        "log_kernel_param": jnp.log(jnp.asarray([drw_scale[band], drw_sigma[band]]))
    }
    # Build the kernel with linear (not log) parameters, matching how
    # Exp(scale=..., sigma=...) is used everywhere else in this notebook.
    # NOTE(review): the simulated amplitudes follow `sim_params`, so the kernel's
    # own parameter values appear unused by UniVarSim -- confirm in eztaox.
    kernel = Exp(scale=drw_scale[band], sigma=drw_sigma[band])
    s = UniVarSim(kernel, min_dt, max_dt, sim_params)
    # simulate the noiseless light curve at random epochs
    sim_t, sim_y = s.random(
        n_obs,
        jax.random.PRNGKey(lc_seed),
        jax.random.PRNGKey(sampling_seed[band]),
    )
    ts[band] = sim_t
    ys[band] = sim_y
    # add simulated photometric noise
    sim_yerr = jnp.full_like(sim_t, yerr_level)
    yerrs[band] = sim_yerr
    ys_noisy[band] = add_noise(sim_y, sim_yerr, jax.random.PRNGKey(noise_seed[band]))

# Delay the r band relative to the g (reference) band; rebinding instead of
# mutating keeps the intent explicit.
ts["r"] = ts["r"] + true_lag

fig, ax = plt.subplots()
for band in "gr":
    ax.errorbar(ts[band], ys_noisy[band], yerrs[band], fmt=".", label=f"{band}-band")
ax.set_xlabel("Time (day)")
ax.set_ylabel("Magnitude (mag)")
ax.legend(fontsize=15);
[3]:
<matplotlib.legend.Legend at 0x75d5f5dbb390>
2. Light Curve Formatting#
To fit multi-band data, you need to put the light curves into a specific format. If your light curves are stored in a dictionary keyed by band name (see the example in Section 1), you can use the following function to format them. The outputs are X, y, and yerr:
X: A tuple of arrays in the format of (time, band index)
time: An array of time stamps for observations in all bands.
band index: An array of integers starting at 0, with the same size as the time array. Observations with the same band index belong to the same band. The band assigned a band index of 0 is treated as the ‘reference’ band.
y: An array of observed values (from all bands).
yerr: An array of observational uncertainties associated with y.
[4]:
from eztaox.ts_utils import formatlc
[5]:
# Map each band name to an integer index; index 0 ("g") is the reference band.
band_index = {"g": 0, "r": 1}
X, y, yerr = formatlc(ts, ys_noisy, yerrs, band_index)

# Every formatted array must have one entry per observation.
assert X[0].shape == X[1].shape == y.shape == yerr.shape, "formatted arrays must align"
# Show a compact summary rather than dumping every element of the arrays.
{"time": X[0].shape, "band index": X[1].shape, "y": y.shape, "yerr": yerr.shape}
[5]:
((Array([ 34., 39., 44., 48., 77., 100., 133., 161., 167.,
185., 188., 277., 278., 285., 293., 301., 309., 310.,
337., 343., 350., 388., 423., 468., 477., 512., 515.,
532., 536., 537., 542., 546., 625., 642., 648., 653.,
691., 692., 750., 778., 779., 795., 822., 846., 881.,
896., 902., 904., 946., 967., 1051., 1090., 1152., 1184.,
1190., 1232., 1269., 1282., 1323., 1339., 1354., 1373., 1375.,
1399., 1434., 1436., 1440., 1456., 1478., 1486., 1487., 1489.,
1492., 1496., 1523., 1528., 1534., 1550., 1571., 1575., 1585.,
1593., 1596., 1605., 1619., 1620., 1624., 1649., 1671., 1748.,
1753., 1816., 1821., 1824., 1853., 1865., 1867., 1928., 1936.,
1937., 1942., 1952., 1971., 1986., 1998., 2001., 2006., 2037.,
2077., 2092., 2101., 2103., 2149., 2195., 2209., 2218., 2236.,
2260., 2265., 2287., 2316., 2320., 2337., 2350., 2401., 2425.,
2452., 2478., 2494., 2503., 2546., 2547., 2555., 2567., 2568.,
2575., 2596., 2597., 2599., 2636., 2645., 2649., 2677., 2694.,
2722., 2755., 2763., 2764., 2835., 2841., 2842., 2843., 2852.,
2865., 2872., 2907., 2911., 2920., 2947., 2963., 2973., 3027.,
3037., 3039., 3041., 3066., 3105., 3106., 3136., 3139., 3154.,
3160., 3165., 3171., 3181., 3201., 3219., 3232., 3252., 3266.,
3272., 3284., 3292., 3293., 3296., 3305., 3330., 3347., 3374.,
3385., 3414., 3430., 3468., 3474., 3479., 3504., 3570., 3588.,
3594., 3631., 7., 15., 58., 66., 83., 99., 101.,
123., 150., 192., 219., 220., 327., 376., 386., 399.,
423., 424., 425., 444., 449., 460., 466., 479., 483.,
486., 542., 564., 571., 576., 616., 622., 636., 645.,
652., 656., 665., 669., 762., 773., 819., 831., 841.,
852., 868., 895., 937., 950., 951., 971., 997., 1025.,
1076., 1098., 1130., 1138., 1139., 1150., 1179., 1202., 1238.,
1272., 1282., 1283., 1288., 1312., 1458., 1481., 1510., 1548.,
1557., 1559., 1560., 1568., 1584., 1602., 1612., 1651., 1671.,
1675., 1677., 1716., 1729., 1751., 1764., 1768., 1776., 1784.,
1817., 1857., 1867., 1868., 1879., 1890., 1904., 1911., 1983.,
1994., 1998., 2042., 2064., 2096., 2102., 2117., 2121., 2127.,
2137., 2143., 2147., 2168., 2186., 2222., 2239., 2276., 2285.,
2339., 2341., 2354., 2363., 2378., 2387., 2396., 2421., 2422.,
2431., 2432., 2434., 2462., 2482., 2483., 2494., 2528., 2529.,
2531., 2552., 2610., 2611., 2617., 2653., 2662., 2663., 2746.,
2752., 2769., 2776., 2781., 2809., 2837., 2845., 2858., 2860.,
2868., 2918., 2919., 2929., 2932., 2958., 2961., 2972., 2974.,
2998., 3018., 3065., 3075., 3077., 3086., 3093., 3103., 3141.,
3143., 3166., 3172., 3173., 3179., 3195., 3214., 3221., 3239.,
3252., 3254., 3269., 3270., 3288., 3306., 3332., 3335., 3389.,
3399., 3413., 3469., 3470., 3471., 3477., 3511., 3544., 3569.,
3590., 3593., 3651., 3654.], dtype=float64),
Array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1], dtype=int64)),
Array([-3.33760434e-03, -1.28343835e-02, 3.80697030e-02, -3.23851653e-02,
5.86047009e-03, 8.39220969e-02, -1.07134596e-02, 2.75278347e-02,
-4.64268887e-02, 1.55938040e-01, 1.95924170e-01, -1.55709786e-01,
-7.05154252e-02, -3.47829627e-02, 2.28895999e-01, 2.83877585e-01,
1.35090561e-01, 1.12947421e-01, -1.74917569e-01, 1.06781768e-02,
1.12013193e-01, 1.72877781e-01, -3.81178581e-02, 1.67530045e-01,
1.99710200e-01, 2.00478916e-01, 1.56970442e-01, -1.25579732e-02,
-2.75592824e-03, -1.05744469e-01, 3.18473036e-03, 1.23577498e-01,
-2.54884876e-01, -2.37555408e-01, -2.87628748e-01, -3.22175990e-01,
-4.04279566e-01, -4.48060402e-01, -2.09317358e-01, -2.26102975e-01,
-3.20057861e-01, -5.92007993e-02, 9.52777074e-02, 3.10453499e-01,
4.19871224e-01, 4.34186753e-01, 3.58954716e-01, 2.88773567e-01,
-7.43351742e-02, 5.70315946e-02, 2.06538089e-01, 2.82175530e-01,
-1.12762997e-01, -2.44115050e-01, -2.07347922e-01, -3.54315829e-01,
-2.19836721e-01, -2.78127668e-01, -4.02303171e-01, -3.86187200e-01,
-2.75834958e-01, -4.48655246e-01, -3.87672234e-01, -4.70108525e-01,
3.84973305e-02, -1.97336142e-02, -7.05571575e-02, -7.34015543e-02,
-1.73797420e-01, -2.86621853e-01, -2.67678410e-01, -1.22608839e-01,
-8.48208949e-02, -1.13067059e-01, -2.49362162e-01, -2.81104537e-01,
-2.17131437e-01, -1.87819971e-01, -2.14860071e-01, -8.28185800e-02,
-1.93588206e-01, -1.78556533e-01, -2.88981715e-01, 4.57764387e-02,
-1.76431004e-02, -2.10901988e-02, -1.28284512e-01, -1.10186652e-01,
-2.00449934e-01, 5.58012203e-02, -2.60150912e-03, -1.58354923e-01,
4.00031638e-02, -1.11496354e-01, -3.00943404e-01, -1.17527435e-01,
-2.09086266e-01, 3.90215281e-01, 4.57462605e-01, 5.64309864e-01,
4.66846432e-01, 6.50659459e-01, 3.51199615e-01, 9.71738334e-02,
2.65818512e-01, 1.13820219e-01, 4.17496145e-02, 1.31473989e-01,
4.34217521e-01, 3.25223298e-01, 5.44159949e-01, 6.66664300e-01,
6.47854694e-01, -9.65945152e-02, -1.04334841e-01, -8.72105644e-02,
-7.35390447e-02, 1.64699163e-01, 2.24471699e-01, -3.06247318e-03,
1.76235065e-01, 2.07548812e-01, 2.04696664e-01, 4.01406110e-01,
4.51761956e-01, 1.86002735e-01, -3.64321488e-02, -1.14772016e-02,
1.33267973e-02, 2.81482673e-02, 5.56356699e-02, 1.50393014e-01,
1.79941320e-01, 4.64117696e-01, 4.76043721e-01, 5.14854323e-01,
4.47847123e-01, 4.47220827e-01, 5.06183028e-01, 2.07443287e-01,
1.55538043e-01, 1.64883556e-01, 2.00908173e-01, 2.36770069e-01,
3.32007857e-01, 3.54598728e-02, 5.11875406e-03, -1.51554419e-01,
-3.02570982e-03, 1.82065861e-01, 5.39790670e-02, 8.47363990e-02,
3.25492753e-01, 4.39099535e-01, 3.65219230e-01, 3.09232010e-01,
3.18949336e-01, 5.28810132e-01, 3.42086818e-01, 4.22977500e-01,
6.29606433e-01, 5.58334858e-01, 6.30264456e-01, 6.40795284e-01,
4.97611864e-01, 1.97876625e-01, 1.17195362e-01, 6.76201907e-02,
1.81250512e-01, 8.54132597e-02, 2.29912790e-01, 2.17478679e-01,
3.50689043e-02, 7.03821085e-02, -2.47556263e-02, -2.48073679e-01,
-3.94310807e-01, -4.66535431e-01, -3.63575377e-01, -4.21535268e-01,
-4.12997744e-01, -3.99671338e-01, -2.98088834e-01, -3.37513492e-01,
-3.22838018e-01, -2.76448085e-01, -5.89149913e-01, -6.35662250e-01,
-6.01301338e-01, -3.18310813e-01, -3.35320454e-01, -3.51669246e-02,
-6.78498942e-02, -1.04413478e-01, -9.44514354e-02, 7.12074521e-02,
-8.46394102e-02, -1.75153188e-01, -4.40937536e-02, -4.18894156e-01,
1.36337689e-01, 1.07682164e-01, -1.27985565e-01, 1.69160503e-01,
-5.53764256e-02, 1.04929437e-01, 3.45160590e-02, 3.88275647e-02,
-9.76607746e-02, 6.57310100e-02, 9.17220432e-02, 2.45231271e-02,
-1.37994987e-03, 1.10125291e-02, -6.27726072e-03, 1.03646755e-01,
1.17636537e-02, 9.57309182e-02, -2.79284207e-03, -1.49749872e-02,
5.13542352e-02, 5.13532252e-02, 2.62194595e-02, 1.19869936e-01,
2.11930611e-01, 1.58813958e-01, -7.32154065e-02, 4.09568765e-02,
-6.47339885e-02, -1.10127552e-01, 4.64010604e-03, -5.17074844e-03,
-2.28511944e-01, -1.85942878e-01, -1.03492776e-01, -1.15653574e-01,
-1.26986564e-01, -3.03878548e-01, -8.24685224e-02, -8.28690526e-02,
-1.59106092e-02, 1.23469025e-01, 7.98305670e-02, 2.53549770e-01,
1.88479387e-01, 2.33939914e-01, 2.46203094e-01, 6.07247609e-02,
3.29463396e-02, -2.82371786e-02, 1.29592747e-01, 1.32466179e-01,
5.17659641e-02, 1.38271992e-01, 3.70413526e-02, 1.54381460e-01,
6.45443104e-02, -8.30944313e-02, 6.49035494e-02, -1.56235776e-01,
-1.77357887e-01, -9.10057549e-02, -2.81100567e-01, -1.25020876e-01,
-1.43383443e-01, -1.81592690e-01, 2.25503214e-03, -2.01213089e-02,
-1.53359109e-01, -2.98959595e-01, -2.84149999e-02, -1.00520124e-01,
-1.22147734e-01, -8.27482890e-02, -1.04002352e-01, -2.03567454e-01,
-1.17944530e-01, -3.22391571e-02, -9.19029046e-02, -6.62715545e-02,
-8.75072617e-02, 1.77048619e-01, 6.64040504e-02, 6.08623399e-02,
1.16973822e-01, 1.31579602e-01, 1.58933695e-01, 2.18631414e-01,
-3.99408132e-02, -2.07158775e-01, 2.78676297e-04, -1.72107112e-02,
-2.54766959e-01, -5.95186347e-02, -9.72301942e-02, 5.95140171e-02,
1.94384722e-01, 1.31409176e-01, 1.07269315e-01, 3.37289182e-02,
1.80939483e-01, 1.90782824e-01, 3.57074047e-01, 2.95128216e-01,
3.07798568e-01, 2.36869790e-01, 2.33669550e-01, 4.47380602e-01,
3.15289125e-01, 1.90488679e-01, 1.33771819e-01, 2.92742846e-02,
1.36502153e-02, 1.02589626e-01, 5.06459446e-02, 1.16723412e-01,
9.54247683e-02, 2.21070415e-01, 3.20977299e-01, 1.76331566e-01,
8.00303513e-02, 1.97714457e-01, 7.62209775e-02, 1.15914770e-01,
8.63870368e-02, 1.14005932e-01, -4.30406929e-03, -1.24128165e-01,
-3.76769283e-02, -3.87059782e-02, 7.73445377e-02, -7.17062332e-02,
-1.46106711e-01, -1.72801809e-02, 1.40241483e-01, 1.55770161e-01,
2.45076161e-01, 1.13326434e-01, 1.39206843e-01, 1.72069038e-01,
1.38315820e-01, 8.55036808e-02, 1.06748653e-01, -6.13723136e-02,
-2.46005892e-01, -2.18045421e-01, -1.08021064e-01, -9.79199611e-03,
6.94427704e-02, 2.89831182e-01, 2.49982456e-01, 2.72941350e-01,
2.62893512e-01, 1.60081508e-01, 3.13788368e-01, 3.22406221e-01,
3.01992878e-01, 3.02858043e-01, 3.16113062e-01, 2.71385798e-01,
3.58630386e-01, 4.28261102e-01, 2.30219815e-01, 1.75564730e-01,
-8.01639888e-04, 2.44669770e-02, 5.93996771e-02, 8.93814390e-02,
3.53366872e-02, 1.19163904e-01, 1.48335324e-01, 1.44445940e-01,
9.38549733e-02, -1.52485027e-02, 3.03152265e-02, -1.00663051e-01,
-1.84552560e-01, -2.06961019e-01, -1.63937545e-01, -1.86163450e-01,
-3.05433985e-01, -3.36795821e-01, -2.34740401e-01, -1.70279403e-01,
-2.91704143e-01, -3.12568539e-01, -1.64315211e-01, -9.67735868e-02,
-8.62040871e-02, -6.97495098e-02, -9.69426759e-02, -1.29127287e-01,
-5.99753445e-02, 7.99358803e-02, 7.41160531e-02, 9.50133632e-02,
-1.08150649e-01, -1.70351718e-01, -9.19312933e-02, -2.20996610e-01], dtype=float64),
Array([0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05,
0.05, 0.05, 0.05, 0.05], dtype=float64))
3. Light Curve Fitting#
[6]:
from eztaox.kernels.quasisep import MultibandLowRank
from eztaox.models import MultiVarModel
from eztaox.fitter import random_search
import numpyro
from numpyro.handlers import seed as numpyro_seed
import numpyro.distributions as dist
3.1 Initialize a light curve model#
[7]:
# define model parameters
has_lag = True  # True: fit for inter-band lags
zero_mean = True  # True: DO NOT fit for the light-curve mean
n_band = len(band_index)  # number of bands in the provided light curve (X, y, yerr)
# Band indices run 0..n_band-1, so the largest index must be n_band - 1.
assert int(X[1].max()) + 1 == n_band, "band_index does not match the formatted data"
# Initialize a GP kernel; note the initial parameters are not used in the fitting.
k = Exp(scale=100.0, sigma=1.0)
m = MultiVarModel(X, y, yerr, k, n_band, has_lag=has_lag, zero_mean=zero_mean)
m
[7]:
MultiVarModel(
X=(f64[400], i64[400]),
y=f64[400],
diag=f64[400],
base_kernel_def=<jax._src.util.HashablePartial object at 0x75d5dc794d10>,
multiband_kernel=<class 'eztaox.kernels.quasisep.MultibandLowRank'>,
t_in_bands=[f64[200], f64[200]],
concat_inds_in_bands=i64[400],
n_band=2,
mean_func=None,
amp_scale_func=None,
lag_func=None,
zero_mean=True,
has_jitter=False,
has_lag=True
)
3.2 Maximum Likelihood (MLE) Fitting#
To find the best-fit parameters, one can start at a random point in the parameter space and optimize the likelihood function until it stops changing. The likelihood for a given set of parameters can be obtained by calling MultiVarModel.log_prob(params). However, I find that this approach often gets stuck in local optima. EzTaoX provides a fitter function (random_search) to alleviate this issue (to some extent). The random_search function first performs a random search (i.e., it evaluates the likelihood at a large number of randomly chosen positions in the parameter space) and then selects a few (defaults to five) positions with the highest likelihood for additional non-linear optimization (e.g., using L-BFGS-B).
The random_search function takes the following arguments:
model: an instance of MultiVarModel.
init_sampler: a custom function (you need to provide) for generating random samples for the random search step.
prng_key: a JAX random number generator key.
n_sample: number of random samples to draw.
n_best: number of best samples to keep for continued optimization.
batch_size: The number of likelihood evaluations performed at a time. Defaults to 1000; for simpler models (and if you have enough memory), you can set this to n_sample.
3.2.1 Define Init Sampler#
Model Parameters:
log_kernel_param: The parameters of the latent GP process.
log_amp_scale: This parameter characterizes the log of the ratio between the amplitude of the GP in each band and that of the latent GP (i.e., the \(S\) parameter in the kernel function). Since the \(S\) parameter of the reference band is set to 1 by default, log_amp_scale is an array of size M-1, where M is the number of bands.
mean: The mean of the light curve in each band, an array of size M.
lag: The inter-band lags with respect to the reference band. lag is an array of size M-1.
[8]:
def init_sampler():
    """Draw one set of model parameters from uniform priors.

    Passed to ``random_search`` to generate random starting points and called
    inside ``numpyro_model`` as the MCMC prior. Keep the order of the
    ``numpyro.sample`` calls fixed: a seeded handler hands out random keys in
    call order, so reordering changes the draws.

    Returns
    -------
    dict
        ``log_kernel_param`` (stacked log DRW timescale and log DRW amplitude),
        ``log_amp_scale``, ``mean`` (one entry per band) and ``lag``.
    """
    # Latent DRW kernel parameters, sampled in log space.
    log_tau = numpyro.sample(
        "log_drw_scale", dist.Uniform(jnp.log(0.01), jnp.log(1000))
    )
    log_amp = numpyro.sample(
        "log_drw_sigma", dist.Uniform(jnp.log(0.01), jnp.log(10))
    )
    kernel_vec = jnp.stack([log_tau, log_amp])
    # Record the stacked vector so it appears in the posterior trace.
    numpyro.deterministic("log_kernel_param", kernel_vec)

    # Log amplitude ratio of the non-reference band to the latent GP.
    amp_ratio = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    # Per-band means. NOTE(review): sampled even when zero_mean=True; the MCMC
    # summary then shows a posterior equal to this prior.
    band_means = numpyro.sample(
        "mean",
        dist.Uniform(low=jnp.full(2, -0.1), high=jnp.full(2, 0.1)),
    )

    # Inter-band lag relative to the reference band.
    band_lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    return {
        "log_kernel_param": kernel_vec,
        "log_amp_scale": amp_ratio,
        "mean": band_means,
        "lag": band_lag,
    }
[9]:
# Draw one random initial guess from the prior. `seed` supplies the PRNG key
# that numpyro.sample needs when called outside an inference algorithm.
sample_key = jax.random.PRNGKey(1)
seeded_sampler = numpyro_seed(init_sampler, rng_seed=sample_key)
prior_sample = seeded_sampler()
prior_sample
[9]:
{'log_kernel_param': Array([-2.19138685, 0.3506587 ], dtype=float64),
'log_amp_scale': Array(-0.62220133, dtype=float64),
'mean': Array([-0.04758097, 0.00985753], dtype=float64),
'lag': Array(-1.42400368, dtype=float64)}
A note on model parameters:#
log_kernel_param: The parameters of the latent GP process.
log_amp_scale: This parameter characterizes the log of the ratio between the amplitude of the GP in each band and that of the latent GP (i.e., the \(S\) parameter in the kernel function). Since the \(S\) parameter of the reference band is set to 1 by default, log_amp_scale is an array of size M-1, where M is the number of bands.
mean: The mean of the light curve in each band, an array of size M.
lag: The inter-band lags with respect to the reference band. lag is an array of size M-1.
3.2.2 Try MLE Fitting#
[10]:
%%time
# Random search over prior draws, then local optimization of the best ones.
fit_key = jax.random.PRNGKey(1)
n_sample = 1_000  # number of random draws from init_sampler
n_best = 5  # draws kept for optimization; increase if repeated fits disagree
bestP, ll = random_search(m, init_sampler, fit_key, n_sample, n_best)
bestP
CPU times: user 3.03 s, sys: 145 ms, total: 3.17 s
Wall time: 3.12 s
[10]:
{'lag': Array(5.00010063, dtype=float64),
'log_amp_scale': Array(-0.52967679, dtype=float64),
'log_kernel_param': Array([ 4.61422257, -1.3001678 ], dtype=float64),
'mean': Array([0.07334045, 0.04293498], dtype=float64)}
4.0 MCMC#
4.1 Define numpyro MCMC model#
[11]:
def numpyro_model(m):
    """numpyro model for MCMC: draw parameters from the prior and hand them to ``m``.

    Parameters
    ----------
    m : MultiVarModel
        Light-curve model; its ``sample`` method presumably registers the GP
        likelihood of the data for the given parameters (see eztaox).
    """
    m.sample(init_sampler())
4.2 Run MCMC#
[12]:
from numpyro.infer import MCMC, NUTS, init_to_median
import arviz as az
[13]:
%%time
# Reuse the light-curve model `m` from Section 3.1 (same data, kernel and
# has_lag/zero_mean settings); the priors come from init_sampler.
nuts_kernel = NUTS(
    numpyro_model,
    dense_mass=True,
    target_accept_prob=0.9,
    init_strategy=init_to_median,
)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
)
mcmc_seed = 0
mcmc.run(jax.random.PRNGKey(mcmc_seed), m)
data = az.from_numpyro(mcmc)
mcmc.print_summary()
sample: 100%|██████████| 1500/1500 [00:04<00:00, 366.51it/s, 3 steps of size 5.18e-01. acc. prob=0.93]
mean std median 5.0% 95.0% n_eff r_hat
lag 4.98 0.85 4.98 3.64 6.28 1010.77 1.00
log_amp_scale -0.53 0.04 -0.53 -0.59 -0.46 1768.69 1.00
log_drw_scale 4.72 0.32 4.68 4.20 5.17 676.90 1.00
log_drw_sigma -1.25 0.14 -1.27 -1.46 -1.04 627.40 1.00
mean[0] 0.00 0.06 0.00 -0.08 0.09 1518.36 1.00
mean[1] -0.00 0.06 0.00 -0.08 0.09 1190.29 1.00
Number of divergences: 0
CPU times: user 7.92 s, sys: 298 ms, total: 8.22 s
Wall time: 8.16 s
4.3 Visualize Chains, Posterior Distributions#
[14]:
trace_vars = ["log_drw_scale", "log_drw_sigma", "log_amp_scale", "lag"]
az.plot_trace(data, var_names=trace_vars)
# Extra vertical space so subplot titles do not overlap.
plt.subplots_adjust(hspace=0.4)
[15]:
# Simulation truths. The latent GP is the g (reference, index 0) band, and the
# r-band amplitude is the g-band amplitude scaled by exp(log_amp_scale).
truths = {
    "log_drw_scale": np.log(drw_scale["g"]),
    "log_drw_sigma": np.log(drw_sigma["g"]),
    "log_amp_scale": np.log(drw_sigma["r"] / drw_sigma["g"]),
    "lag": 5.0,
}
az.plot_pair(
    data,
    var_names=list(truths),
    reference_values=truths,
    reference_values_kwargs={"color": "orange", "markersize": 20, "marker": "s"},
    kind="scatter",
    marginals=True,
    textsize=25,
)
plt.subplots_adjust(hspace=0.0, wspace=0.0)
5.0 Second-order Statistics#
[16]:
from eztaox.kernel_stat2 import gpStat2
[17]:
# Log-spaced evaluation grids (np.logspace defaults to 50 points): time lags
# for the structure function, and frequencies (presumably for a PSD).
# NOTE(review): this rebinds `ts`, which held the per-band light-curve
# timestamps (dict); re-running earlier cells after this one will fail.
ts = np.logspace(-1.0, 3.0, num=50)
fs = np.logspace(-3.0, 3.0, num=50)
5.1 Get MCMC Samples#
[18]:
# Merge (chain, draw) into one trailing "sample" dimension, then transpose so
# that each row is one posterior draw.
flatPost = data.posterior.stack(sample=["chain", "draw"])
log_drw_draws, log_amp_scale_draws, lag_draws = (
    flatPost[var].values.T for var in ("log_kernel_param", "log_amp_scale", "lag")
)
5.2 g-band SF#
[19]:
# Second-order statistics of the true g-band DRW. g is the reference band, so
# its kernel is directly comparable to the latent-GP posterior draws.
g_drw = Exp(scale=drw_scale["g"], sigma=drw_sigma["g"])
gpStat2_g = gpStat2(g_drw)
# One SF per posterior draw: vmap over rows of the linear-space kernel
# parameters while broadcasting the shared `ts` lag grid.
drw_param_draws = jnp.exp(log_drw_draws)
sf_per_draw = jax.vmap(gpStat2_g.sf, in_axes=(None, 0))
mcmc_sf_g = sf_per_draw(ts, drw_param_draws)
[20]:
## plot
# true SF (black) on top of the SFs implied by the MCMC draws (green)
plt.loglog(ts, gpStat2_g.sf(ts), c="k", label="True g-band SF", zorder=100, lw=2)
plt.loglog(ts, mcmc_sf_g[0], label="MCMC g-band SF", c="tab:green", alpha=0.8, lw=2)
# every 20th MCMC draw as a faint envelope (unlabelled, so not in the legend)
for sf in mcmc_sf_g[::20]:
    plt.loglog(ts, sf, c="tab:green", alpha=0.15)
plt.legend(fontsize=15)
plt.xlabel("Time lag (day)")
plt.ylabel("SF");
[20]:
Text(0, 0.5, 'SF')
5.3 r-band SF#
[21]:
# create a 2nd-order statistic object using the true r-band kernel
r_drw = Exp(scale=drw_scale["r"], sigma=drw_sigma["r"])
gpStat2_r = gpStat2(r_drw)
# r-band parameters per draw: the timescale is shared with the latent GP, and
# the amplitude is scaled by exp(log_amp_scale), i.e.
# log_sigma_r = log_sigma_latent + log_amp_scale.
log_drw_draws_r = log_drw_draws.copy()  # copy so the latent-GP draws stay untouched
log_drw_draws_r[:, 1] += log_amp_scale_draws
mcmc_sf_r = jax.vmap(gpStat2_r.sf, in_axes=(None, 0))(ts, jnp.exp(log_drw_draws_r))
[22]:
## plot
# true SF (black) on top of the SFs implied by the MCMC draws (red)
plt.loglog(ts, gpStat2_r.sf(ts), c="k", label="True r-band SF", zorder=100, lw=2)
plt.loglog(ts, mcmc_sf_r[0], label="MCMC r-band SF", c="tab:red", alpha=0.8, lw=2)
# every 20th MCMC draw as a faint envelope (unlabelled, so not in the legend)
for sf in mcmc_sf_r[::20]:
    plt.loglog(ts, sf, c="tab:red", alpha=0.15)
plt.legend(fontsize=15)
plt.xlabel("Time lag (day)")
plt.ylabel("SF");
[22]:
Text(0, 0.5, 'SF')
6.0 Lag distribution#
[23]:
# Posterior of the g-r lag against the 5-day lag injected in Section 1.
_ = plt.hist(lag_draws, density=True)
# axvline spans the whole y-range, so the y-axis can autoscale; the previous
# vlines(ymax=1) relied on a hard-coded ylim that could clip the histogram.
plt.axvline(5.0, color="k", lw=2, label="True g-r Lag")
plt.legend(fontsize=15, loc=2)
plt.xlabel("Lag (day)")
plt.ylabel("Density")
plt.ylim(bottom=0);
[23]:
(0.0, 0.5)