Source code for eztaox.kernels.transfer_function

"""Transfer functions"""

from __future__ import annotations

from abc import abstractmethod

import equinox as eqx
import jax
import tinygp
from jax import numpy as jnp
from tinygp.helpers import JAXArray

from eztaox.kernels.eqx_utils import find_param_by_name
from eztaox.kernels.quasisep import Quasisep


class TransferFunction(eqx.Module):
    r"""Abstract base for transfer functions :math:`\Psi(\Delta t)`.

    Concrete subclasses implement :meth:`evaluate`.
    """

    # Width parameter of the transfer function.
    width: float
    # Offset parameter of the transfer function; defaults to a scalar zero
    # array when not supplied.
    shift: JAXArray | float = eqx.field(
        default_factory=lambda: jnp.zeros(()),
    )

    @abstractmethod
    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the transfer function for the pair of points ``(X1, X2)``.

        This base implementation only raises; subclasses must override it.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError
class GaussianTransferFunction(TransferFunction):
    r"""Gaussian transfer function.

    :math:`\Psi(\Delta t)\propto e^{-((\Delta t-\Delta t_0)/w)^2}`, where
    :math:`\Delta t_0=\mathrm{shift}` and :math:`w=\mathrm{width}`.

    The unity-normalization coefficient is :math:`\frac{1}{\sqrt{\pi}w}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        r"""Return the normalized Gaussian at lag ``X2 - X1``.

        Normalized so that
        :math:`\int_{-\infty}^{\infty}\Psi(\Delta t)\,d\Delta t=1`.
        """
        # Lag measured from the peak, in units of the width.
        scaled_lag = (X2 - X1 - self.shift) / self.width
        return jnp.exp(-jnp.square(scaled_lag)) / (jnp.sqrt(jnp.pi) * self.width)
class ExponentialTransferFunction(TransferFunction):
    r"""Two-sided exponential transfer function.

    :math:`\Psi(\Delta t)\propto e^{-|\Delta t-\Delta t_0|/w}`, where
    :math:`\Delta t_0=\mathrm{shift}` and :math:`w=\mathrm{width}`.

    The unity-normalization coefficient is :math:`\frac{1}{2w}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        r"""Return the normalized two-sided exponential at lag ``X2 - X1``.

        Normalized so that
        :math:`\int_{-\infty}^{\infty}\Psi(\Delta t)\,d\Delta t=1`.
        """
        # Absolute distance of the lag from the peak at ``shift``.
        distance = jnp.abs(X2 - X1 - self.shift)
        return jnp.exp(-distance / self.width) / (2.0 * self.width)
class CausalGaussianTransferFunction(TransferFunction):
    r"""Causal (one-sided) Gaussian transfer function.

    :math:`\Psi(\Delta t)\propto e^{-((\Delta t-\Delta t_0)/w)^2}` for
    :math:`\Delta t\ge0` and zero otherwise, where
    :math:`\Delta t_0=\mathrm{shift}` and :math:`w=\mathrm{width}`.

    The unity-normalization coefficient is
    :math:`\left[\frac{\sqrt{\pi}}{2}w\,\mathrm{erfc}(-\mathrm{shift}/w)\right]^{-1}`,
    which equals
    :math:`\left[\frac{\sqrt{\pi}}{2}w\left(1+\mathrm{erf}(\mathrm{shift}/w)\right)\right]^{-1}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        r"""Return the normalized causal Gaussian at lag ``X2 - X1``.

        Normalized so that
        :math:`\int_{-\infty}^{\infty}\Psi(\Delta t)\,d\Delta t=1` for any
        shift. The result is zero wherever ``X2 - X1 < 0``.
        """
        lag = X2 - X1
        scaled_lag = (lag - self.shift) / self.width
        # 1 + erf(x) == erfc(-x). The erfc form avoids the catastrophic
        # cancellation of ``1 + erf(x)`` for negative x, which in single
        # precision rounds to exactly 0 (and hence an infinite result) well
        # before the true value underflows.
        norm = (
            jnp.sqrt(jnp.pi)
            / 2
            * self.width
            * jax.scipy.special.erfc(-self.shift / self.width)
        )
        return jnp.where(lag >= 0, jnp.exp(-jnp.square(scaled_lag)) / norm, 0.0)
class CausalExponentialTransferFunction(TransferFunction):
    r"""Causal (one-sided) exponential transfer function.

    :math:`\Psi(\Delta t)\propto e^{-(\Delta t-\Delta t_0)/w}` for
    :math:`\Delta t\ge\Delta t_0` and zero otherwise, where
    :math:`\Delta t_0=\mathrm{shift}` and :math:`w=\mathrm{width}`.

    The unity-normalization coefficient is :math:`\frac{1}{w}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        r"""Return the normalized causal exponential at lag ``X2 - X1``.

        Normalized so that
        :math:`\int_{-\infty}^{\infty}\Psi(\Delta t)\,d\Delta t=1` for any
        shift. The result is zero wherever ``X2 - X1 - shift < 0``.
        """
        dt = X2 - X1 - self.shift
        causal = dt >= 0
        # Mask the argument *before* exponentiating. With a single ``where``
        # the discarded branch still evaluates ``exp(-dt / width)`` for
        # negative ``dt``; when that overflows to inf the forward value is
        # still 0, but reverse-mode differentiation produces 0 * inf = NaN
        # gradients with respect to ``width`` and ``shift``.
        safe_dt = jnp.where(causal, dt, 0.0)
        return jnp.where(causal, jnp.exp(-safe_dt / self.width) / self.width, 0.0)
class ConvolvedKernel(tinygp.kernels.Kernel):
    r"""Base kernel convolved with a transfer function, computed via FFT.

    The convolved kernel follows from the Wiener-Khinchin relation:

    :math:`S_{\mathrm{conv}}(f)=S_{\mathrm{base}}(f)\,|\hat{\Psi}(f)|^2`

    :math:`k_{\mathrm{conv}}(\tau)=\mathrm{IFFT}[S_{\mathrm{conv}}](\tau)`

    where :math:`\hat{\Psi}` is the Fourier transform of the transfer
    function and :math:`S_{\mathrm{base}}` is the power spectral density of
    the base kernel.
    """

    # Only ``.power()`` of the base kernel is used here, so this could be
    # extended to "direct" kernels.
    base_kernel: Quasisep
    transfer_function: TransferFunction
    # Number of samples in the FFT grid.
    n_grid: int = eqx.field(static=True)
    # Multiplier applied to (mean base-kernel scale + transfer-function width)
    # to obtain the grid half-width.
    truncation_factor: float = eqx.field(static=True, default=6.0)

    def coord_to_sortable(self, X) -> JAXArray:
        """Return ``X[0]``, the time-sortable component of the coordinates."""
        return X[0]

    @property
    def _half_width(self):
        """Half-width of the integration grid around its center."""
        scale_params = find_param_by_name(self.base_kernel, "scale")
        mean_scale = sum(scale_params) / len(scale_params)
        return (mean_scale + self.transfer_function.width) * self.truncation_factor

    @property
    def _center(self):
        """Center of the integration grid (the transfer-function shift)."""
        return self.transfer_function.shift

    def evaluate(self, X1, X2) -> JAXArray:
        """Evaluate the convolved kernel at the lag ``|X1 - X2|``.

        The transfer function is sampled on a uniform grid of ``n_grid``
        points, its squared FFT modulus is multiplied by the base kernel's
        power spectral density, and the inverse FFT of that product is
        linearly interpolated at the requested lag.
        """
        lag = jnp.abs(X1 - X2)
        half_width = self._half_width
        n_samples = self.n_grid

        # Uniform grid of total length 4 * half_width centred on the shift:
        # the transfer-function support plus zero-padding (2x support).
        step = 4 * half_width / n_samples
        lag_grid = self._center - 2 * half_width + jnp.arange(n_samples) * step

        # Sample Psi on the grid, measuring lags from zero.
        psi_samples = self.transfer_function.evaluate(jnp.zeros(n_samples), lag_grid)

        # S_conv(f) = S_base(f) * |Psi_hat(f)|^2 on the rFFT frequency grid.
        psi_hat = jnp.fft.rfft(psi_samples)
        frequencies = jnp.fft.rfftfreq(n_samples, d=step)
        conv_psd = self.base_kernel.power(frequencies) * jnp.abs(psi_hat) ** 2

        # Inverse FFT gives k_conv on a uniform lag grid with spacing ``step``.
        conv_kernel = step * jnp.fft.irfft(conv_psd, n=n_samples)

        # The first n // 2 + 1 entries are the non-negative lags.
        n_nonneg = n_samples // 2 + 1
        return jnp.interp(lag, jnp.arange(n_nonneg) * step, conv_kernel[:n_nonneg])