Source code for eztaox.ts_utils

"""Utility functions for time series processing."""

import jax
import jax.numpy as jnp
from numpy.typing import NDArray
from tinygp.helpers import JAXArray


@jax.jit
def _get_nearest_idx(t_in, x) -> int:
    """Get the index of the nearest value in `t_in` to `x`.

    Args:
        t_in (JAXArray): Array of time values.
        x (float): The value to find the nearest index for.
    """
    return jnp.argmin(jnp.abs(t_in - x))


[docs] def down_sample_by_time(t_in, t_out) -> JAXArray: """Downsample `t_in` to match the time points in `t_out`. Args: t_in (JAXArray): Array of time values to be downsampled. t_out (JAXArray): Array of target time values. Returns: JAXArray: Downsampled array of time values. """ return t_in[jax.vmap(_get_nearest_idx, in_axes=(None, 0))(t_in, t_out)]
[docs] def formatlc( ts: dict[str, NDArray | JAXArray], ys: dict[str, NDArray | JAXArray], yerrs: dict[str, NDArray | JAXArray], band_order: dict[str, int], ) -> tuple[tuple[JAXArray, JAXArray], JAXArray, JAXArray]: """Transform light curves in dictionary to EzTaoX friendly format. Args: ts (dict[str, NDArray | JAXArray]): Time stamps for observation in each band. ys (dict[str, NDArray | JAXArray]): Observed values in each band. yerrs (dict[str, NDArray | JAXArray]): Uncertainties in observed values for each band. band_order (dict[str, int]): Mapping of band names to band indices. Returns: tuple[tuple[JAXArray, JAXArray], JAXArray, JAXArray]: Light curves formatted as ((time stamps, band indices), observed values, uncertainties). """ band_keys = band_order.keys() tss = jnp.concat([ts[key] for key in band_keys]) yss = jnp.concat([ys[key] for key in band_keys]) yerrss = jnp.concat([yerrs[key] for key in band_keys]) band_idxs = jnp.concat( [jnp.ones(len(ts[x])) * band_order[x] for x in band_keys] ).astype(int) return (tss, band_idxs), yss, yerrss
[docs] def add_noise(y: JAXArray, yerr: JAXArray, key: jax.random.PRNGKey) -> JAXArray: """Add Gaussian noise to a time series given measurement uncertainties. JAX-compatible (works with jit/vmap). Args: y (JAXArray): Input values to which noise will be added. yerr (JAXArray): Associated errors for the input values. key (jax.random.PRNGKey): Pseudorandom number generator key used to draw the noise samples. Returns: JAXArray: Array of the same shape as ``y`` with additive Gaussian noise applied. """ y = jnp.asarray(y) yerr = jnp.asarray(yerr) # JAX needs an explicit PRNG key noise = jax.random.normal(key, shape=y.shape, dtype=y.dtype) * yerr return y + noise
[docs] def merge_two_sorted_argsort(a, b): """Merge two sorted arrays and return the argsort permutation. Args: a (JAXArray): First sorted array. b (JAXArray): Second sorted array. Returns: JAXArray: Indices that would sort the concatenation of ``a`` and ``b``. """ a = jnp.asarray(a) b = jnp.asarray(b) n = a.shape[0] m = b.shape[0] N = n + m perm = jnp.empty((N,), dtype=jnp.int32) def body(k, carry): i, j, perm = carry a_valid = i < n b_valid = j < m ai = jnp.where(a_valid, a[i], 0) bj = jnp.where(b_valid, b[j], 0) take_a = a_valid & (~b_valid | (ai <= bj)) idx = jnp.where(take_a, i, n + j) perm = perm.at[k].set(idx) i = i + take_a.astype(jnp.int32) j = j + (~take_a).astype(jnp.int32) return i, j, perm _, _, perm = jax.lax.fori_loop(0, N, body, (jnp.int32(0), jnp.int32(0), perm)) return perm
[docs] def merge_many_sorted_argsort(arrays): """Merge multiple sorted arrays and return the argsort permutation. Args: arrays (list[JAXArray]): List of sorted arrays to merge. Returns: JAXArray: Indices that would sort the concatenation of all input arrays. """ arrays = [jnp.asarray(a) for a in arrays] if len(arrays) == 0: return jnp.array([], dtype=jnp.int32) lengths = [a.shape[0] for a in arrays] offsets = [] s = 0 for n in lengths: offsets.append(s) s += n items = [ (a, jnp.arange(a.shape[0], dtype=jnp.int32) + off) for a, off in zip(arrays, offsets, strict=True) ] while len(items) > 1: new_items = [] for i in range(0, len(items), 2): if i + 1 == len(items): new_items.append(items[i]) else: av, ai = items[i] bv, bi = items[i + 1] p = merge_two_sorted_argsort(av, bv) new_items.append( ( jnp.concatenate([av, bv])[p], jnp.concatenate([ai, bi])[p], ) ) items = new_items return items[0][1]
@jax.jit
[docs] def merge_sort(*arrays) -> JAXArray: """Merge multiple sorted arrays and return the argsort permutation. Args: *arrays: Variable number of sorted arrays to merge. Returns: JAXArray: Indices that would sort the concatenation of all input arrays. """ return merge_many_sorted_argsort(list(arrays))