"""Utility functions for time series processing."""
import jax
import jax.numpy as jnp
from numpy.typing import NDArray
from tinygp.helpers import JAXArray
@jax.jit
def _get_nearest_idx(t_in, x) -> int:
    """Locate the entry of `t_in` that lies closest to `x`.

    Args:
        t_in (JAXArray): Array of time values to search.
        x (float): Value whose nearest neighbour in `t_in` is wanted.

    Returns:
        Index (as produced by ``jnp.argmin``) of the element of `t_in` with
        the smallest absolute difference from `x`.
    """
    distance = jnp.abs(t_in - x)
    return jnp.argmin(distance)
def down_sample_by_time(t_in, t_out) -> JAXArray:
    """Pick, for every time in `t_out`, the closest time available in `t_in`.

    Args:
        t_in (JAXArray): Array of time values to be downsampled.
        t_out (JAXArray): Array of target time values.

    Returns:
        JAXArray: The values of `t_in` nearest to each entry of `t_out`.
    """
    # Vectorise the nearest-index lookup over the target times only; the
    # full `t_in` array is shared by every lookup.
    nearest_for = jax.vmap(lambda target: _get_nearest_idx(t_in, target))
    return t_in[nearest_for(t_out)]
def add_noise(y: JAXArray, yerr: JAXArray, key: jax.random.PRNGKey) -> JAXArray:
    """Add Gaussian noise to a time series given measurement uncertainties.

    JAX-compatible (works with jit/vmap): only static shape and dtype
    information is inspected in Python.

    Args:
        y (JAXArray): Input values to which noise will be added. Integer or
            boolean input is accepted; the noise is then drawn in the default
            floating-point dtype and the result is promoted accordingly.
        yerr (JAXArray): Associated errors for the input values, used as the
            per-point scale of the noise.
        key (jax.random.PRNGKey): Pseudorandom number generator key used to
            draw the noise samples.

    Returns:
        JAXArray: ``y`` with additive Gaussian noise applied (same shape as
        ``y`` whenever ``yerr`` broadcasts to that shape).
    """
    y = jnp.asarray(y)
    yerr = jnp.asarray(yerr)

    # jax.random.normal only accepts float/complex dtypes. Keep y's own dtype
    # when it qualifies (so existing callers draw the exact same samples) and
    # fall back to the default float dtype for integer/boolean input.
    if jnp.issubdtype(y.dtype, jnp.inexact):
        noise_dtype = y.dtype
    else:
        noise_dtype = jnp.result_type(float)

    # JAX needs an explicit PRNG key
    noise = jax.random.normal(key, shape=y.shape, dtype=noise_dtype) * yerr
    return y + noise
def merge_two_sorted_argsort(a, b):
    """Merge two sorted arrays and return the argsort permutation.

    The merge is stable: on ties the element from ``a`` is taken before the
    element from ``b``. Indices ``0..len(a)-1`` refer to ``a`` and indices
    ``len(a)..len(a)+len(b)-1`` refer to ``b``.

    Args:
        a (JAXArray): First sorted array.
        b (JAXArray): Second sorted array.

    Returns:
        JAXArray: ``int32`` indices that would sort the concatenation of
        ``a`` and ``b``.
    """
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    n = a.shape[0]
    m = b.shape[0]
    N = n + m

    # Shapes are static, so this guard is resolved at trace time. When either
    # input is empty the concatenation is already sorted and the permutation
    # is the identity; returning early also keeps the loop body from indexing
    # into a zero-length array.
    if n == 0 or m == 0:
        return jnp.arange(N, dtype=jnp.int32)

    perm = jnp.empty((N,), dtype=jnp.int32)

    def body(k, carry):
        i, j, perm = carry
        a_valid = i < n
        b_valid = j < m
        # The 0 placeholders are never decisive: an exhausted side is ruled
        # out by the validity flags in `take_a` before values are compared.
        ai = jnp.where(a_valid, a[i], 0)
        bj = jnp.where(b_valid, b[j], 0)
        # `<=` gives ties to `a`, which makes the merge stable.
        take_a = a_valid & (~b_valid | (ai <= bj))
        idx = jnp.where(take_a, i, n + j)
        perm = perm.at[k].set(idx)
        # Advance exactly one of the two cursors per output slot.
        i = i + take_a.astype(jnp.int32)
        j = j + (~take_a).astype(jnp.int32)
        return i, j, perm

    _, _, perm = jax.lax.fori_loop(0, N, body, (jnp.int32(0), jnp.int32(0), perm))
    return perm
def merge_many_sorted_argsort(arrays):
    """Merge multiple sorted arrays and return the argsort permutation.

    The arrays are merged pairwise, round by round, until a single run is
    left; the indices travelling with that run are the result.

    Args:
        arrays (list[JAXArray]): List of sorted arrays to merge.

    Returns:
        JAXArray: Indices that would sort the concatenation of all input arrays.
    """
    arrays = [jnp.asarray(arr) for arr in arrays]
    if not arrays:
        return jnp.array([], dtype=jnp.int32)

    # Pair each array with the positions its elements occupy in the
    # concatenation of all inputs.
    runs = []
    start = 0
    for values in arrays:
        size = values.shape[0]
        runs.append((values, jnp.arange(size, dtype=jnp.int32) + start))
        start += size

    while len(runs) > 1:
        merged_runs = []
        for (left_vals, left_idx), (right_vals, right_idx) in zip(
            runs[0::2], runs[1::2]
        ):
            order = merge_two_sorted_argsort(left_vals, right_vals)
            merged_runs.append(
                (
                    jnp.concatenate([left_vals, right_vals])[order],
                    jnp.concatenate([left_idx, right_idx])[order],
                )
            )
        # An odd run out has no partner this round; carry it over unchanged.
        if len(runs) % 2 == 1:
            merged_runs.append(runs[-1])
        runs = merged_runs

    return runs[0][1]
@jax.jit
def merge_sort(*arrays) -> JAXArray:
    """Merge multiple sorted arrays and return the argsort permutation.

    Jit-compiled entry point that forwards its positional arguments, as a
    list, to ``merge_many_sorted_argsort``.

    Args:
        *arrays: Variable number of sorted arrays to merge.

    Returns:
        JAXArray: Indices that would sort the concatenation of all input arrays.
    """
    return merge_many_sorted_argsort(list(arrays))