Source code for TimeFrequencyWaveforms.code.TD_to_TFD_transform
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator
from typing import Tuple
from functools import partial
from TimeFrequencyWaveforms.code import utils
import WDM
[docs]
class Transformer:
    """
    Interpolation-based evaluator of wavelet-domain coefficients.

    At construction time the four kernel functions ``utils.cnm``,
    ``utils.snm``, ``utils.chatnm`` and ``utils.shatnm`` are tabulated at the
    reference time pixel ``n_ref`` on a regular grid of frequency offsets
    (and, optionally, of ``fdot`` and ``fddot``). The tables are wrapped in
    linear ``RegularGridInterpolator`` objects, which ``coeffs``,
    ``cnm_snm`` and ``transform`` then query for arbitrary inputs.
    """

    def __init__(self,
                 wdm: "WDM.WDM.WDM_transform",
                 num_freq_points: int = 100,
                 fdot_grid_spec: tuple = None,
                 fddot_grid_spec: tuple = None,
                 num_pixels: int = None) -> None:
        """
        Store the configuration and pre-compute the interpolation tables.

        Parameters
        ----------
        wdm : WDM.WDM
            An instance of the WDM wavelet transform class. The attributes
            ``Nt``, ``Nf`` and ``dF`` are read from it.
        num_freq_points : int
            Number of frequency points in the interpolation grid. Optional.
        fdot_grid_spec : tuple
            Specification for fdot grid as (min, max, num_points). Optional.
            Default is None, meaning no interpolation over fdot.
        fddot_grid_spec : tuple
            Specification for fddot grid as (min, max, num_points). Optional.
            Default is None, meaning no interpolation over fddot. Requires
            ``fdot_grid_spec`` to be given as well.
        num_pixels : int
            Number of pixels along the ``Nf`` (``m`` index) axis that are
            tabulated, centred on ``m_ref``. If None, then use all Nf
            pixels. Optional.

        Returns
        -------
        None
        """
        self.wdm = wdm
        self.num_freq_points = num_freq_points

        # Reference pixel at which the kernels are tabulated: an even time
        # index near Nt/2 and the central m index.
        self.n_ref = 2 * (self.wdm.Nt // 4)
        self.m_ref = self.wdm.Nf // 2
        self.f_ref = self.m_ref * self.wdm.dF

        # The frequency grid spans one period of width 2*dF, expressed as an
        # offset from f_ref.
        self.f_grid_spec = (0., 2 * self.wdm.dF, self.num_freq_points)
        self.fdot_grid_spec = fdot_grid_spec
        self.fddot_grid_spec = fddot_grid_spec
        if self.fddot_grid_spec is not None:
            assert self.fdot_grid_spec is not None, \
                "fdot_grid_spec must be provided if fddot_grid_spec is."

        self.n_vals = jnp.arange(0, self.wdm.Nt)
        self.m_vals = jnp.arange(0, self.wdm.Nf)
        # Alternating sign (-1)^(n+m), shape (Nt, Nf).
        self.alt = (-1)**(self.n_vals[:, jnp.newaxis]
                          + self.m_vals[jnp.newaxis, :])

        self.num_pixels = num_pixels if num_pixels is not None else self.wdm.Nf
        assert self.num_pixels <= self.wdm.Nf, \
            "num_pixels cannot be larger than wdm.Nf."

        # The m indices that are tabulated: num_pixels consecutive values
        # starting at m_ref - num_pixels//2.
        self.m_pixel_range = self.m_ref - \
            self.num_pixels // 2 + \
            jnp.arange(self.num_pixels)
        if self.num_pixels == self.wdm.Nf:
            # Sanity check: the full range must coincide with 0..Nf-1.
            assert jnp.all(jnp.equal(self.m_pixel_range, self.m_vals)), \
                "I've messed up!"

        # Mask of shape (Nt, Nf) that is True on rows with even n.
        self.mask_n_even = (self.n_vals % 2) == 0
        self.mask_n_even = jnp.outer(self.mask_n_even,
                                     jnp.ones(len(self.m_vals), dtype=bool))

        self.grids = self.make_grids()
        self.grid_shape = tuple(grid.shape[0] for grid in self.grids)
        self.dim = len(self.grids)

        (self.cnm_interp,
         self.snm_interp,
         self.chatnm_interp,
         self.shatnm_interp) = self.make_interpolators()

    def make_grids(self) -> Tuple[jnp.ndarray, ...]:
        """
        Make the regular grid for interpolation.

        Returns
        -------
        grids : tuple of jnp.ndarray
            One 1D array of grid points per interpolation axis, ordered as
            (frequency offset, fdot, fddot). The fdot and fddot axes are only
            present when the corresponding grid specification is not None.
        """
        specs = [self.f_grid_spec]
        if self.fdot_grid_spec is not None:
            specs.append(self.fdot_grid_spec)
        if self.fddot_grid_spec is not None:
            specs.append(self.fddot_grid_spec)
        # Each spec is (min, max, num_points), i.e. the arguments of linspace.
        return tuple(jnp.linspace(*spec) for spec in specs)

    def make_interpolators(self) -> Tuple[RegularGridInterpolator,
                                          RegularGridInterpolator,
                                          RegularGridInterpolator,
                                          RegularGridInterpolator]:
        """
        Create interpolators for the coefficients cnm, snm, chatnm and shatnm.

        Every kernel is evaluated at the reference time pixel ``n_ref`` for
        each grid point and each m in ``m_pixel_range``; the resulting tables
        have shape ``grid_shape + (num_pixels,)``.

        Returns
        -------
        interpolators : tuple
            Four interpolators for cnm, snm, chatnm, shatnm.
        """
        kernels = (utils.cnm, utils.snm, utils.chatnm, utils.shatnm)
        table_shape = self.grid_shape + (self.num_pixels,)
        tables = [jnp.zeros(table_shape) for _ in kernels]

        # Axes that are not interpolated over are replaced by a single
        # placeholder value of 0.0, so one loop nest serves dim = 1, 2 and 3.
        fdot_axis = self.grids[1] if self.dim >= 2 else (0.0,)
        fddot_axis = self.grids[2] if self.dim >= 3 else (0.0,)

        for i, f in enumerate(self.grids[0] + self.f_ref):
            for j, fdot in enumerate(fdot_axis):
                for k, fddot in enumerate(fddot_axis):
                    # Keep only the indices of axes that exist in the table.
                    grid_index = (i, j, k)[:self.dim]
                    for m_, m in enumerate(self.m_pixel_range):
                        index = grid_index + (m_,)
                        for t, kernel in enumerate(kernels):
                            tables[t] = tables[t].at[index].set(
                                kernel(self.wdm, self.n_ref, m,
                                       f, fdot, fddot))

        # Queries outside the grid evaluate to 0.0 instead of raising.
        interpolators = tuple(
            RegularGridInterpolator(self.grids, table,
                                    method='linear',
                                    bounds_error=False,
                                    fill_value=0.0)
            for table in tables)
        return interpolators

    @partial(jax.jit, static_argnums=0)
    def coeffs(self,
               F: jnp.ndarray,
               fdot: jnp.ndarray = None,
               fddot: jnp.ndarray = None) -> Tuple[jnp.ndarray,
                                                   jnp.ndarray,
                                                   jnp.ndarray,
                                                   jnp.ndarray]:
        """
        Interpolate to get the coefficients at given parameters.

        The interpolated tables are placed in the columns listed in
        ``m_pixel_range`` of zero-initialised (Nt, Nf) arrays. On rows with
        odd n the plain and hatted quantities are exchanged.

        Parameters
        ----------
        F : jnp.ndarray
            Frequencies, shape=(Nt,).
        fdot : jnp.ndarray
            Frequency derivatives, shape=(Nt,). Optional.
        fddot : jnp.ndarray
            Frequency second derivatives, shape=(Nt,). Optional; may only be
            given together with ``fdot``.

        Returns
        -------
        values : tuple
            The quantities. Each of these are jnp.ndarray with shape=(Nt, Nf).

        Raises
        ------
        ValueError
            If ``fddot`` is given without ``fdot``; the query columns would
            otherwise be matched to the wrong interpolation axes.
        """
        if fdot is None and fddot is not None:
            raise ValueError("fdot must be provided if fddot is.")

        # One query column per supplied parameter, in grid-axis order.
        supplied = [x for x in (F, fdot, fddot) if x is not None]
        query = jnp.array(supplied).T

        full_shape = (self.wdm.Nt, self.wdm.Nf)

        def scatter(interp):
            # Columns outside m_pixel_range stay zero.
            return jnp.zeros(full_shape).at[:, self.m_pixel_range].set(
                interp(query))

        CNM = scatter(self.cnm_interp)
        SNM = scatter(self.snm_interp)
        CHATNM = scatter(self.chatnm_interp)
        SHATNM = scatter(self.shatnm_interp)

        # Even-n rows keep the quantities as tabulated; odd-n rows swap the
        # plain and hatted versions.
        CNM_shifted = jnp.where(self.mask_n_even, CNM, CHATNM)
        SNM_shifted = jnp.where(self.mask_n_even, SNM, SHATNM)
        CHATNM_shifted = jnp.where(self.mask_n_even, CHATNM, CNM)
        SHATNM_shifted = jnp.where(self.mask_n_even, SHATNM, SNM)

        values = (CNM_shifted, SNM_shifted, CHATNM_shifted, SHATNM_shifted)
        return values

    def _fold_frequency(self, f_n: jnp.ndarray) -> Tuple[jnp.ndarray,
                                                         jnp.ndarray]:
        """
        Split ``f_n - f_ref`` into an offset in [0, 2*dF) and an integer
        number of whole 2*dF periods.
        """
        period = 2 * self.wdm.dF
        offset = (f_n - self.f_ref) % period
        z_n = jnp.floor((f_n - self.f_ref) / period).astype(int)
        return offset, z_n

    @partial(jax.jit, static_argnums=0)
    def cnm_snm(self,
                A_n: jnp.ndarray,
                phi_n: jnp.ndarray,
                f_n: jnp.ndarray,
                fdot_n: jnp.ndarray = None,
                fddot_n: jnp.ndarray = None) -> Tuple[jnp.ndarray,
                                                      jnp.ndarray,
                                                      jnp.ndarray,
                                                      jnp.ndarray]:
        r"""
        Evaluate the interpolated coefficient arrays for a waveform.

        The frequency is folded onto the interpolation grid and passed to
        ``coeffs``. ``A_n`` and ``phi_n`` are accepted for signature parity
        with ``transform`` but are not used here.

        Parameters
        ----------
        A_n : jnp.ndarray
            The waveform amplitude :math:`A(t_n)` evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`. Unused.
        phi_n : jnp.ndarray
            The waveform phase :math:`\Phi(t_n)` [rad] evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`. Unused.
        f_n : jnp.ndarray
            The waveform frequency :math:`f(t_n)` [Hz] evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`.
        fdot_n : jnp.ndarray
            The frequency derivative :math:`\dot{f}(t_n)` [Hz/s]
            evaluated at the sparse wavelet times :math:`t_n = n \Delta T`.
        fddot_n : jnp.ndarray
            The frequency second derivative :math:`\ddot{f}(t_n)` [Hz/s^2]
            evaluated at the sparse wavelet times :math:`t_n = n \Delta T`.

        Returns
        -------
        CNM, SNM, CHATNM, SHATNM
            The four arrays returned by ``coeffs``.
        """
        F_n_minus_f_ref, _ = self._fold_frequency(f_n)
        CNM, SNM, CHATNM, SHATNM = self.coeffs(F_n_minus_f_ref,
                                               fdot_n, fddot_n)
        return CNM, SNM, CHATNM, SHATNM

    @partial(jax.jit, static_argnums=0)
    def transform(self,
                  A_n: jnp.ndarray,
                  phi_n: jnp.ndarray,
                  f_n: jnp.ndarray,
                  fdot_n: jnp.ndarray = None,
                  fddot_n: jnp.ndarray = None) -> jnp.ndarray:
        r"""
        Compute the wavelet coefficients of a waveform.

        Parameters
        ----------
        A_n : jnp.ndarray
            The waveform amplitude :math:`A(t_n)` evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`.
        phi_n : jnp.ndarray
            The waveform phase :math:`\Phi(t_n)` [rad] evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`.
        f_n : jnp.ndarray
            The waveform frequency :math:`f(t_n)` [Hz] evaluated at the sparse
            wavelet times :math:`t_n = n \Delta T`.
        fdot_n : jnp.ndarray
            The frequency derivative :math:`\dot{f}(t_n)` [Hz/s]
            evaluated at the sparse wavelet times :math:`t_n = n \Delta T`.
        fddot_n : jnp.ndarray
            The frequency second derivative :math:`\ddot{f}(t_n)` [Hz/s^2]
            evaluated at the sparse wavelet times :math:`t_n = n \Delta T`.

        Returns
        -------
        wnm : jnp.ndarray
            The wavelet coefficients of the waveform; the real part is the
            plus polarisation, the imaginary part is the cross polarisation.
            Array shape=(self.wdm.Nt, self.wdm.Nf), dtype=complex.
        """
        F_n_minus_f_ref, z_n = self._fold_frequency(f_n)
        CNM, SNM, CHATNM, SHATNM = self.coeffs(F_n_minus_f_ref,
                                               fdot_n, fddot_n)

        # Each row is rolled by +/- 2*z_n to undo the frequency folding, and
        # the two rolled copies are combined with the alternating sign.
        cnm = 0.5 * (utils.row_roll(CNM, +2*z_n) +
                     utils.row_roll(CNM, -2*z_n) +
                     self.alt * utils.row_roll(SHATNM, +2*z_n) -
                     self.alt * utils.row_roll(SHATNM, -2*z_n))
        snm = 0.5 * (utils.row_roll(SNM, +2*z_n) +
                     utils.row_roll(SNM, -2*z_n) -
                     self.alt * utils.row_roll(CHATNM, +2*z_n) +
                     self.alt * utils.row_roll(CHATNM, -2*z_n))

        # Apply the per-time-pixel complex amplitude A * exp(i * phi).
        wnm = (A_n[:, jnp.newaxis] *
               jnp.exp(1j * phi_n[:, jnp.newaxis]) *
               (cnm + (1j) * snm))
        return wnm