Fitting Rubin DP1 light curves with LSDB#

This notebook demonstrates how to load DP1 data using LSDB and run EzTaoX fitting on DP1 light curves.

Set up a virtual environment#

To run this notebook on the RSP please create a virtual environment with conda as follows:

conda create -n eztaox python=3.12
conda activate eztaox
pip install eztaox lsdb jupyter arviz netCDF4 matplotlib
python -m ipykernel install --user --name eztaox --display-name "EzTaoX (Python 3.12)"

You should now be able to select the EzTaoX kernel from your Jupyter notebook.

[1]:
from upath import UPath

import numpy as np
import matplotlib.pyplot as plt
import arviz as az

import os
import lsdb
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import jax
import pyarrow as pa

from dask.distributed import Client
from eztaox.kernels.quasisep import Exp
from eztaox.models import MultiVarModel
from eztaox.fitter import random_search
from nested_pandas import NestedDtype
from numpyro.infer import MCMC, NUTS, init_to_median
from numpyro.diagnostics import summary

1. Load DP1 light curves#

First, let’s load the DP1 light curves with LSDB. For this notebook, we will only load long light curves with nDiaSources > 500 and the necessary diaObjectForcedSource columns.

[2]:
dia_object_cat = lsdb.open_catalog(
    "/rubin/lsdb_data/dia_object_collection",
    columns=["ra", "dec", "diaObjectId", "nDiaSources", "diaObjectForcedSource"],
    filters=[["nDiaSources", ">", 500]],
)
print(f"Number of DIA objects in DP1: {len(dia_object_cat):,}")
Number of DIA objects in DP1: 1,089,818
[3]:
dia_object_cat.head()
[3]:
  ra dec diaObjectId nDiaSources diaObjectForcedSource
2528665081895373202 53.122135 -28.344551 609788873886662663 512
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.122135 ... 0.001262 60623.259895
+674 rows ... ... ... ...
2528689212471481767 53.041797 -28.283192 609789629800906756 581
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.041797 ... 0.001272 60623.258521
+731 rows ... ... ... ...
2528689652733440700 53.033514 -28.243268 611253766972244112 567
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.033514 ... 0.001002 60623.258521
+753 rows ... ... ... ...
2528691786073128755 52.898762 -28.193302 611253766972243969 548
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 52.898762 ... 0.00112 60623.258521
+720 rows ... ... ... ...
2528711291375799141 53.321787 -28.312488 609789492361953287 513
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.321787 ... 0.0009 60623.258521
+653 rows ... ... ... ...
5 rows x 5 columns

1.1 Clean light curves#

We will remove all observations with a flag, and keep only objects with detections in both the r and i bands.

Remove objects with flags#

[4]:
flag_subcols = [
    c for c in dia_object_cat["diaObjectForcedSource"].columns if "flag" in c.lower()
]
dia_object_cat = dia_object_cat.query(
    " and ".join([f"diaObjectForcedSource.{c} != 1" for c in flag_subcols])
)

Select two bands, r and i#

[5]:
dia_object_cat = dia_object_cat.query("diaObjectForcedSource.band in ['r','i']")

Remove NA values#

  • Remove all forced sources with NA mjds/mags.

  • Remove all objects that after filtering have no forced sources.

[6]:
def remove_na_sources(df):
    subcols = [
        f"diaObjectForcedSource.{subcol}"
        for subcol in ["midpointMjdTai", "psfMag", "psfMagErr"]
    ]
    return df.dropna(subset=subcols).dropna(subset="diaObjectForcedSource")


dia_object_cat = dia_object_cat.map_partitions(remove_na_sources)
dia_object_cat
[6]:
lsdb Catalog dia_object_lc:
ra dec diaObjectId nDiaSources diaObjectForcedSource
npartitions=208
Order: 6, Pixel: 130 double[pyarrow] double[pyarrow] int64[pyarrow] int64[pyarrow] nested<parentObjectId: [int64], coord_ra: [dou...
Order: 6, Pixel: 136 ... ... ... ... ...
... ... ... ... ... ...
Order: 11, Pixel: 36833621 ... ... ... ... ...
Order: 7, Pixel: 143884 ... ... ... ... ...
5 out of 140 available columns in the catalog have been loaded lazily, meaning no data has been read, only the catalog schema

We will also clean the light curves by:

  • removing 10-sigma outliers within each night

  • binning the data into nightly epochs by taking the mean magnitude for each night

  • updating the magnitude uncertainties by adding, in quadrature, the mean magnitude error and the within-night standard deviation of the magnitudes

[7]:
def prepare_lightcurves(
    time: np.ndarray,
    band: np.ndarray,
    mag: np.ndarray,
    magerr: np.ndarray,
    outlier_snr: float = 10.0,
    binning: bool = True,
) -> dict[str, np.ndarray]:
    """
    Clean light curves by removing outliers and optionally binning by epoch.

    Args:
        time:        Timestamps (MJD).
        band:        Band labels.
        mag:         Magnitudes.
        magerr:      Magnitude errors.
        outlier_snr: Sigma threshold for outlier rejection.
        binning:     If True, combine observations within the same epoch.
    """
    times, mags, magerrs = {}, {}, {}

    for b in np.unique(band):
        band_mask = band == b
        time_band, mag_band, magerr_band = (
            time[band_mask],
            mag[band_mask],
            magerr[band_mask],
        )

        time_band, mag_band, magerr_band = remove_outliers(
            time_band, mag_band, magerr_band, outlier_snr
        )

        if binning:
            t, _, m, me = bin_band(time_band, mag_band, magerr_band, b)
        else:
            t, _, m, me = time_band, np.full(len(time_band), b), mag_band, magerr_band

        times[b] = t
        mags[b] = m
        magerrs[b] = me

    return format_lightcurves(times, mags, magerrs)


def remove_outliers(
    time_band: np.ndarray,
    mag_band: np.ndarray,
    magerr_band: np.ndarray,
    outlier_snr: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Remove data points beyond outlier_snr standard deviations from the mean."""
    if len(time_band) <= 1 or np.nanstd(mag_band) == 0:
        return time_band, mag_band, magerr_band

    sigma = np.abs(mag_band - np.nanmean(mag_band)) / np.nanstd(mag_band)
    mask = sigma < outlier_snr
    return time_band[mask], mag_band[mask], magerr_band[mask]


def bin_band(
    time_band: np.ndarray,
    mag_band: np.ndarray,
    magerr_band: np.ndarray,
    band_label: str,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Bin observations by epoch (rounded to nearest day) and combine magnitudes."""
    rounded_times = np.round(time_band)
    unique_epochs = np.unique(rounded_times)

    binned_times, binned_mags, binned_magerrs = [], [], []
    for epoch in unique_epochs:
        epoch_mask = rounded_times == epoch
        tmp_mag, tmp_err = combine_mag(mag_band[epoch_mask], magerr_band[epoch_mask])
        binned_times.append(epoch)
        binned_mags.append(tmp_mag)
        binned_magerrs.append(tmp_err)

    binned_bands = np.full(len(unique_epochs), band_label)
    return (
        np.array(binned_times),
        binned_bands,
        np.array(binned_mags),
        np.array(binned_magerrs),
    )


def combine_mag(mag, mag_err):
    """Add in quaduature mean(magerr) and std(mag) in each epoch
    ## e.g., https://arxiv.org/abs/2411.06617"""
    mag = np.asarray(mag)
    mag_err = np.asarray(mag_err)
    mag_w = np.nanmean(mag)
    mag_w_err = np.sqrt(np.nanmean(mag_err) ** 2 + np.nanstd(mag) ** 2)
    return mag_w, mag_w_err


def format_lightcurves(times, mags, magerrs):
    """Prepare light curves for fitting"""
    bands = times.keys()
    inds = jnp.argsort(jnp.concatenate([times[b] for b in bands]))
    X = (
        jnp.concatenate([times[b] for b in bands])[inds],
        jnp.concatenate(
            [i * jnp.ones_like(times[b], dtype=int) for i, b in enumerate(bands)]
        )[inds],
    )
    for b in bands:
        mags[b] = jnp.array(mags[b])
        mags[b] -= jnp.median(mags[b])
    y = jnp.concatenate([jnp.array(mags[b]) for b in bands])[inds]
    yerr = jnp.concatenate([jnp.array(magerrs[b]) for b in bands])[inds]
    return X, y, yerr

2. EzTaoX fitting#

First, we need to define a init sampler:

[8]:
def init_sampler(fixed_params=None):
    if fixed_params is None:
        fixed_params = {}

    if "log_drw_scale" in fixed_params:
        log_drw_scale = fixed_params["log_drw_scale"]
    else:
        log_drw_scale = numpyro.sample(
            "log_drw_scale", dist.Uniform(jnp.log(10), jnp.log(100))
        )
    if "log_drw_sigma" in fixed_params:
        log_drw_sigma = fixed_params["log_drw_sigma"]
    else:
        log_drw_sigma = numpyro.sample(
            "log_drw_sigma", dist.Uniform(jnp.log(1e-2), jnp.log(10))
        )
    log_kernel_param = jnp.stack([log_drw_scale, log_drw_sigma])
    numpyro.deterministic("log_kernel_param", log_kernel_param)

    # parameters to relate the amplitudes in each band
    log_amp_scale = numpyro.sample("log_amp_scale", dist.Uniform(-2, 2))

    mean = numpyro.sample(
        "mean",
        dist.Uniform(low=jnp.asarray([-0.1, -0.1]), high=jnp.asarray([0.1, 0.1])),
    )

    # interband lags
    lag = numpyro.sample("lag", dist.Uniform(-10, 10))

    sample_params = {
        "log_kernel_param": log_kernel_param,
        "log_amp_scale": log_amp_scale,
        "mean": mean,
        "lag": lag,
    }
    return sample_params


def numpyro_model(m, fixed_params=None):
    sample_params = init_sampler(fixed_params)
    m.sample(sample_params)

And instantiate a multivariate model to use for fitting:

[9]:
def run_MCMC(
    objid,
    X,
    y,
    yerr,
    numpyro_model,
    has_lag=True,
    zero_mean=True,
    fixed_params=None,
    save_chains=False,
):
    n_band = len(np.unique(X[1]))
    k = Exp(scale=100.0, sigma=1.0)
    m = MultiVarModel(X, y, yerr, k, n_band, has_lag=has_lag, zero_mean=zero_mean)

    nuts_kernel = NUTS(
        numpyro_model,
        dense_mass=True,
        target_accept_prob=0.9,
        init_strategy=init_to_median,
    )
    mcmc = MCMC(
        nuts_kernel,
        num_warmup=500,
        num_samples=1000,
        num_chains=1,
        # progress_bar=False,
    )
    mcmc_seed = jax.random.PRNGKey(0)
    mcmc.run(mcmc_seed, m, fixed_params)

    if save_chains:
        idata = az.from_numpyro(mcmc)
        output_path = UPath(f"mcmc_chains/{objid}.nc")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        idata.to_netcdf(str(output_path))

    return summary(mcmc.get_samples(group_by_chain=True))

Let’s finally put these steps all together:

[10]:
def run_eztaox(objid, times, band, mag, magerr):
    X, y, yerr = prepare_lightcurves(times, band, mag, magerr)
    mcmc_summary = run_MCMC(
        objid,
        X,
        y,
        yerr,
        numpyro_model,
        has_lag=True,
        zero_mean=True,
        fixed_params={"log_drw_scale": jnp.log(100.0)},
        save_chains=True,
    )
    return make_nested(mcmc_summary)


def make_nested(summary):
    nested_dict = {}
    for key in summary:
        for key2 in summary[key]:
            nested_dict[f"{key}.{key2}"] = [summary[key][key2]]
    return nested_dict

Plan the computation and run it in parallel with LSDB:

[11]:
# Stats returned by the mcmc summary
nested_subcols = ["mean", "std", "median", "5.0%", "95.0%", "n_eff", "r_hat"]

# Params returned by the mcmc summary
nested_cols = ["lag", "log_amp_scale", "log_drw_sigma", "log_kernel_param", "mean"]

# Metadata for the catalog operation
meta = {
    col: NestedDtype({sc: pa.float64() for sc in nested_subcols}) for col in nested_cols
}
[12]:
result = dia_object_cat.map_rows(
    run_eztaox,
    # We only need the forced sources time band and mag info
    columns=["diaObjectId"]
    + [
        f"diaObjectForcedSource.{col}"
        for col in ["midpointMjdTai", "band", "psfMag", "psfMagErr"]
    ],
    row_container="args",
    append_columns=True,
    meta=meta,
)
result
[12]:
lsdb Catalog dia_object_lc:
ra dec diaObjectId nDiaSources diaObjectForcedSource lag log_amp_scale log_drw_sigma log_kernel_param mean
npartitions=208
Order: 6, Pixel: 130 double[pyarrow] double[pyarrow] int64[pyarrow] int64[pyarrow] nested<parentObjectId: [int64], coord_ra: [dou... nested<mean: [double], std: [double], median: ... nested<mean: [double], std: [double], median: ... nested<mean: [double], std: [double], median: ... nested<mean: [double], std: [double], median: ... nested<mean: [double], std: [double], median: ...
Order: 6, Pixel: 136 ... ... ... ... ... ... ... ... ... ...
... ... ... ... ... ... ... ... ... ... ...
Order: 11, Pixel: 36833621 ... ... ... ... ... ... ... ... ... ...
Order: 7, Pixel: 143884 ... ... ... ... ... ... ... ... ... ...
10 out of 140 available columns in the catalog have been loaded lazily, meaning no data has been read, only the catalog schema
[13]:
with Client(n_workers=4):
    fitting_df = result.compute()
fitting_df
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 314.28it/s, 15 steps of size 2.66e-01. acc. prob=0.93]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 300.81it/s, 31 steps of size 1.41e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 290.07it/s, 31 steps of size 1.00e-01. acc. prob=0.88]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 314.45it/s, 31 steps of size 9.57e-02. acc. prob=0.90]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 275.77it/s, 15 steps of size 3.07e-01. acc. prob=0.80]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 319.29it/s, 15 steps of size 3.23e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 294.03it/s, 15 steps of size 3.45e-01. acc. prob=0.92]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 318.78it/s, 23 steps of size 1.69e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 295.49it/s, 15 steps of size 2.51e-01. acc. prob=0.95]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 336.88it/s, 15 steps of size 3.96e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 297.57it/s, 7 steps of size 3.11e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 314.65it/s, 15 steps of size 2.33e-01. acc. prob=0.90]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 294.25it/s, 15 steps of size 1.40e-01. acc. prob=0.87]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 336.71it/s, 15 steps of size 3.22e-01. acc. prob=0.96]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 266.09it/s, 15 steps of size 1.95e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 325.71it/s, 15 steps of size 2.77e-01. acc. prob=0.93]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 326.42it/s, 15 steps of size 3.72e-01. acc. prob=0.95]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 321.90it/s, 15 steps of size 2.99e-01. acc. prob=0.88]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 329.05it/s, 15 steps of size 2.52e-01. acc. prob=0.86]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 279.51it/s, 63 steps of size 3.37e-02. acc. prob=0.95]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 317.96it/s, 15 steps of size 3.85e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 303.39it/s, 47 steps of size 8.71e-02. acc. prob=0.67]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 314.01it/s, 31 steps of size 1.46e-01. acc. prob=0.91]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 313.72it/s, 15 steps of size 2.89e-01. acc. prob=0.90]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 314.00it/s, 15 steps of size 2.89e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 271.44it/s, 31 steps of size 1.19e-01. acc. prob=0.88]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 316.82it/s, 31 steps of size 1.15e-01. acc. prob=0.91]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 289.87it/s, 15 steps of size 3.51e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 273.83it/s, 31 steps of size 9.26e-02. acc. prob=0.90]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 307.49it/s, 15 steps of size 2.21e-01. acc. prob=0.90]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 316.89it/s, 15 steps of size 2.47e-01. acc. prob=0.83]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 305.67it/s, 15 steps of size 2.03e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 301.05it/s, 15 steps of size 1.24e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 290.34it/s, 15 steps of size 3.56e-01. acc. prob=0.94]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 305.60it/s, 15 steps of size 2.96e-01. acc. prob=0.97]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 255.56it/s, 63 steps of size 4.14e-02. acc. prob=0.88]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 271.16it/s, 7 steps of size 4.05e-01. acc. prob=0.95]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 282.01it/s, 15 steps of size 3.02e-01. acc. prob=0.93]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 288.17it/s, 15 steps of size 1.44e-01. acc. prob=0.91]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 268.26it/s, 15 steps of size 2.06e-01. acc. prob=0.92]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:06<00:00, 230.25it/s, 63 steps of size 6.94e-02. acc. prob=0.92]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 315.54it/s, 7 steps of size 4.22e-01. acc. prob=0.93]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 287.65it/s, 15 steps of size 3.45e-01. acc. prob=0.86]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 281.77it/s, 63 steps of size 7.32e-02. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:06<00:00, 241.85it/s, 31 steps of size 8.29e-02. acc. prob=0.82]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 289.00it/s, 63 steps of size 5.76e-02. acc. prob=0.93]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 293.75it/s, 31 steps of size 1.06e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 286.74it/s, 15 steps of size 2.78e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 260.07it/s, 15 steps of size 2.64e-01. acc. prob=0.91]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:04<00:00, 306.91it/s, 31 steps of size 1.44e-01. acc. prob=0.88]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 286.07it/s, 31 steps of size 1.69e-01. acc. prob=0.89]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:06<00:00, 225.07it/s, 127 steps of size 3.00e-02. acc. prob=0.91]
sample: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1500/1500 [00:05<00:00, 289.89it/s, 15 steps of size 3.34e-01. acc. prob=0.95]
[13]:
  ra dec diaObjectId nDiaSources diaObjectForcedSource lag log_amp_scale log_drw_sigma log_kernel_param mean
2528665081895373202 53.122135 -28.344551 609788873886662663 512
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.122135 ... 0.001262 60623.259895
+306 rows ... ... ... ...
mean std ... n_eff r_hat
0.114068 5.737555 ... 921.557033 0.999944
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-1.209215 0.57737 ... 1025.835509 0.999001
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-4.3326 0.231873 ... 952.425297 1.004991
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[4.6051701859881495, -4.332599512424976] [5.776048474655431e-14, 0.23187307514124833] ... [1.0000000000000004, 952.4252971320903] [0.9989994994993742, 1.0049914362260401]
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[0.0012089694151780491, -0.003091259171689446] [0.05914417560474353, 0.05687344004893575] ... [864.1404257591242, 1036.9850740283177] [0.999348063941459, 0.9994561479780022]
+0 rows ... ... ... ...
2528689212471481767 53.041797 -28.283192 609789629800906756 581
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.041797 ... 0.001272 60623.258521
+330 rows ... ... ... ...
mean std ... n_eff r_hat
0.578621 5.238823 ... 722.733225 1.001008
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-0.832439 0.667594 ... 518.302736 1.000893
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-4.278913 0.273341 ... 1052.19823 0.999355
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[4.6051701859881495, -4.278912857472897] [5.776048474655431e-14, 0.27334118918379074] ... [1.0000000000000004, 1052.1982297702168] [0.9989994994993742, 0.999355132898022]
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[-0.003840854455763379, -0.0021847650087133417] [0.05936029665788383, 0.059226542424557396] ... [604.4243005023516, 775.7899860016753] [0.9991003689784849, 0.9997898716327556]
+0 rows ... ... ... ...
2528689652733440700 53.033514 -28.243268 611253766972244112 567
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.033514 ... 0.001002 60623.258521
+324 rows ... ... ... ...
mean std ... n_eff r_hat
-2.401319 6.975147 ... 67.958886 1.017161
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-1.107034 0.651808 ... 569.411105 0.999006
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-4.039378 0.365575 ... 464.974843 1.000199
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[4.6051701859881495, -4.039378154295635] [5.776048474655431e-14, 0.3655748437373966] ... [1.0000000000000004, 464.9748425687322] [0.9989994994993742, 1.0001986432354808]
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[0.0020242590127864776, -0.004358119511921289] [0.059489835452456834, 0.057878697455887855] ... [408.4726031797575, 686.4980121155778] [0.9998320050192685, 1.0003391587456807]
+0 rows ... ... ... ...
2528691786073128755 52.898762 -28.193302 611253766972243969 548
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 52.898762 ... 0.00112 60623.258521
+312 rows ... ... ... ...
mean std ... n_eff r_hat
0.88669 4.814158 ... 329.123921 0.999832
+0 rows ... ... ... ...
mean std ... n_eff r_hat
0.504063 0.349293 ... 367.203388 0.99903
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-4.079443 0.334058 ... 539.260457 0.999356
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[4.6051701859881495, -4.079442819364329] [5.776048474655431e-14, 0.3340577456871409] ... [1.0000000000000004, 539.2604566445316] [0.9989994994993742, 0.9993555829694125]
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[0.00022047554581587934, -0.00039017875662901647] [0.05707533604165101, 0.056771575646229276] ... [663.3667144558057, 436.07187332661283] [0.9993875934485403, 1.0098627900396024]
+0 rows ... ... ... ...
2528711291375799141 53.321787 -28.312488 609789492361953287 513
parentObjectId coord_ra ... psfMagErr midpointMjdTai
0 53.321787 ... 0.0009 60623.258521
+270 rows ... ... ... ...
mean std ... n_eff r_hat
-0.173037 5.67528 ... 697.258419 0.999034
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-1.167998 0.589487 ... 613.311294 1.000549
+0 rows ... ... ... ...
mean std ... n_eff r_hat
-3.898214 0.346184 ... 501.619598 0.999179
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[4.6051701859881495, -3.8982144997039527] [5.776048474655431e-14, 0.34618391046486724] ... [1.0000000000000004, 501.61959785318703] [0.9989994994993742, 0.9991790199978499]
+0 rows ... ... ... ...
mean std ... n_eff r_hat
[-0.0011862661677975652, -0.0019238626717922127] [0.05942104000746946, 0.05543160587248729] ... [873.2267154744463, 799.3201322711099] [0.9990252927949145, 0.9991023411032524]
+0 rows ... ... ... ...
... ... ... ... ... ... ...
53 rows x 10 columns

We can load one of the chains as follows:

[14]:
az.from_netcdf("mcmc_chains/611254385447534607.nc")
[14]:
<xarray.DataTree>
Group: /
β”œβ”€β”€ Group: /posterior
β”‚       Dimensions:                 (chain: 1, draw: 1000, log_kernel_param_dim_0: 2,
β”‚                                    mean_dim_0: 2)
β”‚       Coordinates:
β”‚         * chain                   (chain) int64 8B 0
β”‚         * draw                    (draw) int64 8kB 0 1 2 3 4 5 ... 995 996 997 998 999
β”‚         * log_kernel_param_dim_0  (log_kernel_param_dim_0) int64 16B 0 1
β”‚         * mean_dim_0              (mean_dim_0) int64 16B 0 1
β”‚       Data variables:
β”‚           lag                     (chain, draw) float64 8kB ...
β”‚           log_amp_scale           (chain, draw) float64 8kB ...
β”‚           log_drw_sigma           (chain, draw) float64 8kB ...
β”‚           log_kernel_param        (chain, draw, log_kernel_param_dim_0) float64 16kB ...
β”‚           mean                    (chain, draw, mean_dim_0) float64 16kB ...
β”‚       Attributes:
β”‚           created_at:                 2026-04-15T14:41:07.195973+00:00
β”‚           creation_library:           ArviZ
β”‚           creation_library_version:   1.0.0
β”‚           creation_library_language:  Python
β”‚           inference_library:          numpyro
β”‚           inference_library_version:  0.19.0
β”œβ”€β”€ Group: /sample_stats
β”‚       Dimensions:    (chain: 1, draw: 1000)
β”‚       Coordinates:
β”‚         * chain      (chain) int64 8B 0
β”‚         * draw       (draw) int64 8kB 0 1 2 3 4 5 6 7 ... 993 994 995 996 997 998 999
β”‚       Data variables:
β”‚           diverging  (chain, draw) bool 1kB ...
β”‚       Attributes:
β”‚           created_at:                 2026-04-15T14:41:07.203114+00:00
β”‚           creation_library:           ArviZ
β”‚           creation_library_version:   1.0.0
β”‚           creation_library_language:  Python
β”‚           inference_library:          numpyro
β”‚           inference_library_version:  0.19.0
└── Group: /observed_data
        Dimensions:   (gp_dim_0: 29)
        Coordinates:
          * gp_dim_0  (gp_dim_0) int64 232B 0 1 2 3 4 5 6 7 ... 21 22 23 24 25 26 27 28
        Data variables:
            gp        (gp_dim_0) float32 116B ...
        Attributes:
            created_at:                 2026-04-15T14:41:07.206919+00:00
            creation_library:           ArviZ
            creation_library_version:   1.0.0
            creation_library_language:  Python
            inference_library:          numpyro
            inference_library_version:  0.19.0

3. Plot some results#

[15]:
def plot_MCMC(
    X, y, yerr, mcmc, outfile=None, trace_vars=None, has_lag=True, zero_mean=True
):
    # Draw one sample from MCMC and extract parameters
    mcmc_samples = {
        k: v.stack(sample=("chain", "draw")).transpose("sample", ...).values
        for k, v in mcmc.posterior.data_vars.items()
    }
    sample_idx = 100  # Use first sample
    params_sample = {k: v[sample_idx] for k, v in mcmc_samples.items()}

    band_colors = {
        "u": "blue",
        "g": "green",
        "r": "red",
        "i": "orange",
        "z": "purple",
        "y": "brown",
    }

    # Create prediction grid (finer time resolution for smooth curves)
    pred_times = jnp.linspace(X[0].min(), X[0].max(), 200)
    pred_X = (pred_times, jnp.zeros_like(pred_times, dtype=int))  # Band 0

    # Make prediction for each band
    fig = plt.figure(figsize=(12, 6))
    ax = fig.add_gridspec(nrows=1, ncols=2, width_ratios=[1.75, 1], wspace=0.15)

    ax0 = fig.add_subplot(ax[0, 0])
    n_vars = len(trace_vars)
    ax_r = ax[0, 1].subgridspec(nrows=n_vars, ncols=2, hspace=0.4, wspace=0.4)
    ax1 = np.asarray(
        [[fig.add_subplot(ax_r[r, c]) for c in range(2)] for r in range(n_vars)],
        dtype=object,
    )

    n_band = len(np.unique(X[1]))
    k = Exp(scale=100.0, sigma=1.0)  # init params for k are not used
    m = MultiVarModel(X, y, yerr, k, n_band, has_lag=has_lag, zero_mean=zero_mean)
    m.sample(params_sample)
    model = m

    for i, b in enumerate("ri"):
        # Create prediction for this band
        pred_X_band = (pred_times, jnp.full_like(pred_times, i, dtype=int))
        mu, std = model.pred(params_sample, pred_X_band)

        # Plot observed data
        mask = X[1] == i

        # Plot prediction
        ax0.plot(pred_times, mu, "-", linewidth=1, alpha=0.5, color=band_colors[b])
        ax0.fill_between(
            pred_times, mu - std, mu + std, alpha=0.2, color=band_colors[b]
        )
        ax0.errorbar(
            X[0][mask],
            y[mask],
            yerr=yerr[mask],
            fmt=".",
            label=b,
            zorder=-100,
            color=band_colors[b],
        )

        ax0.set_xlabel("Time (days)", fontsize=12)
        ax0.set_ylabel("Scaled Magnitude", fontsize=12)
        ax0.tick_params(direction="in")
        ax0.invert_yaxis()
        ax0.legend(loc="upper left")

    for iv, var in enumerate(trace_vars):
        ax1[iv, 0].plot(np.array(mcmc.posterior[var]).flatten())
        ax1[iv, 1].hist(np.array(mcmc.posterior[var]).flatten())
        ax1[iv, 1].set_ylabel(var)
        ax1[iv, 1].yaxis.set_label_position("right")
        ax1[iv, 0].tick_params(direction="in")
        ax1[iv, 1].tick_params(direction="in")

    ax1[0, 0].set_title("Trace")
    ax1[0, 1].set_title("Posterior")

    if outfile != None:
        plt.savefig(outfile, bbox_inches="tight")
    else:
        plt.show()
[16]:
for i in range(5):
    obj = fitting_df.iloc[i]
    X, y, yerr = prepare_lightcurves(
        obj.diaObjectForcedSource["midpointMjdTai"],
        obj.diaObjectForcedSource["band"],
        obj.diaObjectForcedSource["psfMag"],
        obj.diaObjectForcedSource["psfMagErr"],
    )
    plot_MCMC(
        X,
        y,
        yerr,
        az.from_netcdf(f"mcmc_chains/{str(obj.diaObjectId)}.nc"),
        trace_vars=["log_drw_sigma", "log_amp_scale", "lag"],
    )
../_images/pre_executed_05_DP1_LSDB_28_0.png
../_images/pre_executed_05_DP1_LSDB_28_1.png
../_images/pre_executed_05_DP1_LSDB_28_2.png
../_images/pre_executed_05_DP1_LSDB_28_3.png
../_images/pre_executed_05_DP1_LSDB_28_4.png