# Source code for eztaox.kernels.transfer_function
"""Transfer functions"""
from __future__ import annotations
from abc import abstractmethod
import equinox as eqx
import jax
import tinygp
from jax import numpy as jnp
from tinygp.helpers import JAXArray
from eztaox.kernels.eqx_utils import find_param_by_name
from eztaox.kernels.quasisep import Quasisep
class TransferFunction(eqx.Module):
    """Abstract base for transfer functions :math:`\\Psi(\\Delta t)`.

    Concrete subclasses in this module implement :meth:`evaluate` using the
    two fields declared here.
    """

    # Characteristic width ``w`` that the subclasses divide the lag by.
    width: float
    # Lag offset ``\Delta t_0`` subtracted from ``X2 - X1`` by the subclasses;
    # defaults to a scalar zero array.
    shift: JAXArray | float = eqx.field(default_factory=lambda: jnp.zeros(()))

    @abstractmethod
    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the transfer function for the pair of points ``(X1, X2)``.

        Raises:
            NotImplementedError: Always; subclasses must provide the
                implementation.
        """
        raise NotImplementedError
class GaussianTransferFunction(TransferFunction):
    """Gaussian transfer function: :math:`\\propto e^{-((\\Delta t-\\Delta t_0)/w)^2}`.

    Here :math:`\\Delta t_0=\\mathrm{shift}` and :math:`w=\\mathrm{width}`.

    The unity-normalization coefficient is :math:`\\frac{1}{\\sqrt{\\pi}w}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the normalized Gaussian response for the lag ``X2 - X1``.

        The result is scaled so that
        :math:`\\int_{-\\infty}^{\\infty}\\Psi(\\Delta t)\\,d\\Delta t=1`.
        """
        w = self.width
        # Lag measured relative to the configured shift.
        lag = X2 - X1 - self.shift
        area = jnp.sqrt(jnp.pi) * w
        return jnp.exp(-jnp.square(lag / w)) / area
class ExponentialTransferFunction(TransferFunction):
    """Exponential transfer function: :math:`\\propto e^{-|\\Delta t-\\Delta t_0|/w}`.

    Here :math:`\\Delta t_0=\\mathrm{shift}` and :math:`w=\\mathrm{width}`.

    The unity-normalization coefficient is :math:`\\frac{1}{2w}`.
    """

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the normalized two-sided exponential for the lag ``X2 - X1``.

        The result is scaled so that
        :math:`\\int_{-\\infty}^{\\infty}\\Psi(\\Delta t)\\,d\\Delta t=1`.
        """
        w = self.width
        # Lag measured relative to the configured shift.
        lag = X2 - X1 - self.shift
        area = 2.0 * w
        return jnp.exp(-jnp.abs(lag) / w) / area
class CausalGaussianTransferFunction(TransferFunction):
    """Causal Gaussian: :math:`\\propto e^{-((\\Delta t-\\Delta t_0)/w)^2},\\Delta t\\ge0`.

    Here :math:`\\Delta t_0=\\mathrm{shift}` and :math:`w=\\mathrm{width}`.

    The unity-normalization coefficient is
    :math:`\\left[\\frac{\\sqrt{\\pi}}{2}w\\left(1+\\mathrm{erf}(\\mathrm{shift}/w)\\right)\\right]^{-1}`.
    """  # noqa: E501

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the normalized one-sided Gaussian for the lag ``X2 - X1``.

        The response is zero for negative lags and is scaled so that
        :math:`\\int_{-\\infty}^{\\infty}\\Psi(\\Delta t)\\,d\\Delta t=1`
        for any shift.
        """
        w = self.width
        # The causality cut is applied to the raw lag, not the shifted one.
        raw_lag = X2 - X1
        centered = raw_lag - self.shift
        # Area of the Gaussian restricted to non-negative raw lags.
        half_line_area = (
            jnp.sqrt(jnp.pi) / 2 * w * (1 + jax.scipy.special.erf(self.shift / w))
        )
        response = jnp.exp(-jnp.square(centered / w)) / half_line_area
        return jnp.where(raw_lag >= 0, response, 0.0)
class CausalExponentialTransferFunction(TransferFunction):
    """Causal exponential: :math:`\\propto e^{-(\\Delta t-\\Delta t_0)/w},\\Delta t\\ge\\Delta t_0`.

    Here :math:`\\Delta t_0=\\mathrm{shift}` and :math:`w=\\mathrm{width}`.
    Defined for :math:`\\Delta t\\ge\\Delta t_0`, zero otherwise.

    The unity-normalization coefficient is :math:`\\frac{1}{w}`.
    """  # noqa: E501

    def evaluate(self, X1: JAXArray, X2: JAXArray) -> JAXArray:
        """Return the normalized one-sided exponential for the lag ``X2 - X1``.

        The response is zero where ``X2 - X1 - shift`` is negative and is
        scaled so that
        :math:`\\int_{-\\infty}^{\\infty}\\Psi(\\Delta t)\\,d\\Delta t=1`
        for any shift.
        """
        w = self.width
        dt = X2 - X1 - self.shift
        causal = dt >= 0
        # Mask the lag *before* exponentiating: for strongly negative dt,
        # exp(-dt / w) overflows to inf in the branch that jnp.where discards,
        # and differentiating through that branch yields NaN gradients for
        # width/shift. Forward values are unchanged by this masking.
        safe_dt = jnp.where(causal, dt, 0.0)
        return jnp.where(causal, jnp.exp(-safe_dt / w) / w, 0.0)
class ConvolvedKernel(tinygp.kernels.Kernel):
    """Kernel convolved with a transfer function via FFT.

    The convolved kernel is obtained through the Wiener-Khinchin relation:

    :math:`S_{\\mathrm{conv}}(f)=S_{\\mathrm{base}}(f)\\,|\\hat{\\Psi}(f)|^2`

    :math:`k_{\\mathrm{conv}}(\\tau)=\\mathrm{IFFT}[S_{\\mathrm{conv}}](\\tau)`

    where :math:`\\hat{\\Psi}` is the Fourier transform of the transfer function
    and :math:`S_{\\mathrm{base}}` is the power spectral density of the base
    kernel.
    """

    # Only ``.power()`` is actually needed from the base kernel, so this could
    # be extended to "direct" kernels.
    base_kernel: Quasisep
    transfer_function: TransferFunction
    # Number of points on the uniform grid used for the FFT (static field).
    n_grid: int = eqx.field(static=True)
    # Multiplier applied to (mean scale + width) to size the grid (static field).
    truncation_factor: float = eqx.field(static=True, default=6.0)

    def coord_to_sortable(self, X) -> JAXArray:
        """Return the time-sortable component of the coordinates, ``X[0]``."""
        return X[0]

    @property
    def _half_width(self):
        """Half-width of the integration grid around the center.

        Computed as ``(mean_scale + width) * truncation_factor``, where
        ``mean_scale`` averages the values returned by
        ``find_param_by_name(base_kernel, "scale")`` and ``width`` is the
        transfer function's width.
        """
        scales = find_param_by_name(self.base_kernel, "scale")
        mean_scale = sum(scales) / len(scales)
        tf_width = self.transfer_function.width
        return (mean_scale + tf_width) * self.truncation_factor

    @property
    def _center(self):
        """Center of the integration grid: the transfer function's shift."""
        return self.transfer_function.shift

    def evaluate(self, X1, X2) -> JAXArray:
        """Evaluate the convolved kernel at the lag ``|X1 - X2|``.

        The transfer function is sampled on a uniform grid of ``n_grid``
        points, combined with the base kernel's ``power`` in the frequency
        domain, transformed back to the lag domain, and linearly interpolated
        at the requested lag with ``jnp.interp``.
        """
        n_pts = self.n_grid
        half_width = self._half_width
        origin = self._center
        lag = jnp.abs(X1 - X2)

        # Uniform grid of total length 4 * half_width starting at
        # origin - 2 * half_width (TF support plus zero-padding).
        span = 4 * half_width
        step = span / n_pts
        samples_at = origin - 2 * half_width + jnp.arange(n_pts) * step

        # Sample Psi on the grid, using zeros as the reference point.
        psi = self.transfer_function.evaluate(jnp.zeros(n_pts), samples_at)

        # Frequency domain: S_conv(f) = S_base(f) * |Psi_hat(f)|^2
        psi_hat = jnp.fft.rfft(psi)
        frequencies = jnp.fft.rfftfreq(n_pts, d=step)
        base_power = self.base_kernel.power(frequencies)
        conv_power = base_power * jnp.abs(psi_hat) ** 2

        # Back to the lag domain on the same uniform spacing.
        conv_cov = step * jnp.fft.irfft(conv_power, n=n_pts)

        # Only the first half of the output (non-negative lags) is used.
        n_nonneg = n_pts // 2 + 1
        lag_axis = jnp.arange(n_nonneg) * step
        return jnp.interp(lag, lag_axis, conv_cov[:n_nonneg])