import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from functools import partial
import WDM
[docs]
@partial(jax.jit, static_argnums=0)
def Xn(wdm : WDM.WDM.WDM_transform,
       n : int,
       m : int,
       f : float,
       fdot : float=0.0,
       fddot : float=0.0) -> jnp.ndarray:
    r"""
    Compute the phase term :math:`X_n(t)` at the sample times stored in the wdm
    object.

    .. math::

        X_n(f, \dot{f}, \ldots) = 2\pi (t-n\Delta T) f
            + \pi (t-n\Delta T)^2 \dot{f}
            + \frac{1}{3} \pi (t-n\Delta T)^3 \ddot{f} .

    With :math:`\tau = t - n\Delta T` this equals
    :math:`2\pi\int_0^{\tau}\bigl(f + \dot{f}\tau' + \tfrac{1}{2}\ddot{f}\tau'^2\bigr)\,\mathrm{d}\tau'`,
    which is the origin of the :math:`1/3` coefficient on the
    :math:`\ddot{f}` term.

    ``wdm`` is a static argument of :func:`jax.jit`. It must therefore be
    hashable, and a separate trace is compiled for every distinct ``wdm``
    object.

    Parameters
    ----------
    wdm : WDM.WDM.WDM_transform
        An instance of the WDM wavelet transform class. Only ``wdm.times``
        and ``wdm.dT`` are read here.
    n : int
        Wavelet time index.
    m : int
        Wavelet frequency index. Not used by this function; it is accepted
        so the call signature matches the coefficient functions.
    f : float
        Frequency at which to evaluate the phase term [Hz].
    fdot : float
        Frequency derivative [Hz/s]. Optional.
    fddot : float
        Second frequency derivative [Hz/s/s]. Optional.

    Returns
    -------
    Xn : jnp.ndarray
        The phase term evaluated at every entry of ``wdm.times`` (same
        shape as ``wdm.times``, i.e. shape=(N,)).
    """
    # Time offset from n * dT, computed once instead of three times.
    tau = wdm.times - n * wdm.dT
    # The fddot coefficient is pi/3 (= 2*pi/3!). Writing it as ``1./.3``
    # would give 10/3 and overweight the fddot term by a factor of 10.
    return (2. * jnp.pi * tau * f
            + jnp.pi * tau**2 * fdot
            + (jnp.pi / 3.) * tau**3 * fddot)
[docs]
def cnm(wdm : WDM.WDM.WDM_transform,
        n : int,
        m : int,
        f : float,
        fdot : float=0.0,
        fddot : float=0.0) -> jnp.ndarray:
    r"""
    Compute the coefficient :math:`c_{nm}(f,\dot{f},\ldots)`.

    .. math::

        c_{nm}(f,\dot{f},\ldots) = \int\mathrm{d}t\;
            \cos X_n(f,\dot{f},\ldots)\, g_{nm}(t).

    The integral is evaluated as a sum over the samples of ``wdm.gnm(n, m)``
    and :func:`Xn`, weighted by ``wdm.dt``. This assumes the two arrays are
    sampled on the same uniform time grid with spacing ``wdm.dt``
    (TODO: confirm against the WDM class).

    Parameters
    ----------
    wdm : WDM.WDM.WDM_transform
        An instance of the WDM wavelet transform class.
    n : int
        Wavelet time index.
    m : int
        Wavelet frequency index.
    f : float
        Frequency entering the phase term :math:`X_n` [Hz].
    fdot : float
        Frequency derivative [Hz/s]. Optional.
    fddot : float
        Second frequency derivative [Hz/s/s]. Optional.

    Returns
    -------
    c : jnp.ndarray
        The coefficient :math:`c_{nm}`, a scalar (0-d array). This is not a
        per-sample array: ``jnp.sum`` reduces over all samples.
    """
    g = wdm.gnm(n, m)
    X = Xn(wdm, n, m, f, fdot=fdot, fddot=fddot)
    # Riemann-sum approximation of the time integral.
    c = wdm.dt * jnp.sum(jnp.cos(X) * g)
    return c
[docs]
def snm(wdm : WDM.WDM.WDM_transform,
        n : int,
        m : int,
        f : float,
        fdot : float=0.0,
        fddot : float=0.0) -> jnp.ndarray:
    r"""
    Compute the coefficient :math:`s_{nm}(f,\dot{f},\ldots)`.

    .. math::

        s_{nm}(f,\dot{f},\ldots) = \int\mathrm{d}t\;
            \sin X_n(f,\dot{f},\ldots)\, g_{nm}(t).

    The integral is evaluated as a sum over the samples of ``wdm.gnm(n, m)``
    and :func:`Xn`, weighted by ``wdm.dt``. This assumes the two arrays are
    sampled on the same uniform time grid with spacing ``wdm.dt``
    (TODO: confirm against the WDM class).

    Parameters
    ----------
    wdm : WDM.WDM.WDM_transform
        An instance of the WDM wavelet transform class.
    n : int
        Wavelet time index.
    m : int
        Wavelet frequency index.
    f : float
        Frequency entering the phase term :math:`X_n` [Hz].
    fdot : float
        Frequency derivative [Hz/s]. Optional.
    fddot : float
        Second frequency derivative [Hz/s/s]. Optional.

    Returns
    -------
    s : jnp.ndarray
        The coefficient :math:`s_{nm}`, a scalar (0-d array). This is not a
        per-sample array: ``jnp.sum`` reduces over all samples.
    """
    g = wdm.gnm(n, m)
    X = Xn(wdm, n, m, f, fdot=fdot, fddot=fddot)
    # Riemann-sum approximation of the time integral.
    s = wdm.dt * jnp.sum(jnp.sin(X) * g)
    return s
[docs]
def chatnm(wdm : WDM.WDM.WDM_transform,
           n : int,
           m : int,
           f : float,
           fdot : float=0.0,
           fddot : float=0.0) -> jnp.ndarray:
    r"""
    Compute the coefficient :math:`\hat{c}_{nm}(f,\dot{f},\ldots)`.

    .. math::

        \hat{c}_{nm}(f,\dot{f},\ldots) = \int\mathrm{d}t\;
            \cos X_n(f,\dot{f},\ldots)\, \hat{g}_{nm}(t).

    Identical to the :math:`c_{nm}` computation except that the window is
    ``wdm.gnm_dual(n, m)``. The integral is evaluated as a sum over samples
    weighted by ``wdm.dt``. This assumes a uniform time grid with spacing
    ``wdm.dt`` (TODO: confirm against the WDM class).

    Parameters
    ----------
    wdm : WDM.WDM.WDM_transform
        An instance of the WDM wavelet transform class.
    n : int
        Wavelet time index.
    m : int
        Wavelet frequency index.
    f : float
        Frequency entering the phase term :math:`X_n` [Hz].
    fdot : float
        Frequency derivative [Hz/s]. Optional.
    fddot : float
        Second frequency derivative [Hz/s/s]. Optional.

    Returns
    -------
    chat : jnp.ndarray
        The coefficient :math:`\hat{c}_{nm}`, a scalar (0-d array). This is
        not a per-sample array: ``jnp.sum`` reduces over all samples.
    """
    ghat = wdm.gnm_dual(n, m)
    X = Xn(wdm, n, m, f, fdot=fdot, fddot=fddot)
    # Riemann-sum approximation of the time integral.
    chat = wdm.dt * jnp.sum(jnp.cos(X) * ghat)
    return chat
[docs]
def shatnm(wdm : WDM.WDM.WDM_transform,
           n : int,
           m : int,
           f : float,
           fdot : float=0.0,
           fddot : float=0.0) -> jnp.ndarray:
    r"""
    Compute the coefficient :math:`\hat{s}_{nm}(f,\dot{f},\ldots)`.

    .. math::

        \hat{s}_{nm}(f,\dot{f},\ldots) = \int\mathrm{d}t\;
            \sin X_n(f,\dot{f},\ldots)\, \hat{g}_{nm}(t).

    Identical to the :math:`s_{nm}` computation except that the window is
    ``wdm.gnm_dual(n, m)``. The integral is evaluated as a sum over samples
    weighted by ``wdm.dt``. This assumes a uniform time grid with spacing
    ``wdm.dt`` (TODO: confirm against the WDM class).

    Parameters
    ----------
    wdm : WDM.WDM.WDM_transform
        An instance of the WDM wavelet transform class.
    n : int
        Wavelet time index.
    m : int
        Wavelet frequency index.
    f : float
        Frequency entering the phase term :math:`X_n` [Hz].
    fdot : float
        Frequency derivative [Hz/s]. Optional.
    fddot : float
        Second frequency derivative [Hz/s/s]. Optional.

    Returns
    -------
    shat : jnp.ndarray
        The coefficient :math:`\hat{s}_{nm}`, a scalar (0-d array). This is
        not a per-sample array: ``jnp.sum`` reduces over all samples.
    """
    ghat = wdm.gnm_dual(n, m)
    X = Xn(wdm, n, m, f, fdot=fdot, fddot=fddot)
    # Riemann-sum approximation of the time integral.
    shat = wdm.dt * jnp.sum(jnp.sin(X) * ghat)
    return shat
[docs]
@jax.jit
def row_roll(A: jnp.ndarray,
             shifts: jnp.ndarray) -> jnp.ndarray:
    """
    Circularly shift every row of a 2-D array by its own integer offset.

    Row ``i`` of the result is row ``i`` of ``A`` rolled along the column
    axis by ``shifts[i]`` positions, i.e.
    ``out[i, j] == A[i, (j - shifts[i]) mod M]``. Positive shifts move
    elements towards higher column indices. Shifts may be negative or larger
    than ``M``; they wrap around.

    Parameters
    ----------
    A : jnp.ndarray
        Input array, shape=(N, M).
    shifts : jnp.ndarray
        Integer array of per-row shifts, shape=(N,), dtype=int.

    Returns
    -------
    B : jnp.ndarray
        Row-rolled copy of ``A``, shape=(N, M).
    """
    # Unpacking both dimensions rejects inputs that are not 2-D.
    _, n_cols = A.shape
    # For every output slot, the column of A it is read from.
    source_cols = jnp.mod(
        jnp.arange(n_cols)[jnp.newaxis, :] - shifts[:, jnp.newaxis],
        n_cols,
    )
    return jnp.take_along_axis(A, source_cols, axis=1)