# Source code for WDM.code.discrete_wavelet_transform.WDM

import jax
import jax.numpy as jnp
from matplotlib.pylab import indices

from WDM.code.utils.Meyer import Meyer
from WDM.code.utils.utils import C_nm, overlapping_windows

from typing import Tuple
from functools import partial


class WDM_transform:
    r"""
    This class implements the WDM discrete wavelet transform.

    Attributes
    ----------
    dt : float
        The cadence, or time step, of the time series, :math:`\delta t`.
        Equal to inverse of the sampling frequency.
    Nf : int
        Number of wavelet frequency bands, :math:`N_f`. Must be even.
        This controls the time/frequency resolution of the wavelets.
    N : int
        Length of the input time series, :math:`N`.
        Must be an even multiple of :math:`N_f`.
    Nt : int
        Number of wavelet time bands, :math:`N_t`. Equal to :math:`N/N_f`.
        Must be even.
    q : int
        Truncation parameter, :math:`q`. Formally the time domain wavelets
        have infinite extent but in practice are truncated at
        :math:`\pm q \Delta T`. This must be an integer in the range
        :math:`1 \leq q \leq N_t/2`.
    d : int
        Steepness parameter for the Meyer window transition.
        Must be a positive integer, :math:`d\geq 1`.
    A_frac : float
        Fraction of total bandwidth used for the flat-top response region.
        Validated to lie in the open range (0, 1). Note that
        :math:`B_{\mathrm{frac}} = 1 - 2A_{\mathrm{frac}}` is only positive
        for :math:`A_{\mathrm{frac}} < 0.5`.
    B_frac : float
        Fraction of total bandwidth used for the transition region. This is
        set based on A_frac so :math:`2A_{\mathrm{frac}}+B_{\mathrm{frac}}=1`.
    A : float
        Half-width of the flat-top response region in angular frequency
        (radians per unit time), :math:`A`.
        Satisfies :math:`\Delta \Omega = 2A + B`.
    B : float
        Width of the transition region in angular frequency (radians per
        unit time), :math:`B`. Satisfies :math:`\Delta \Omega = 2A + B`.
    dF : float
        Frequency resolution of the wavelets, or the total wavelet frequency
        bandwidth :math:`\Delta F = \frac{\Delta \Omega}{2 \pi}`.
    dT : float
        Time resolution of the wavelets. Related to the frequency resolution
        by :math:`\Delta F \Delta T = \frac{1}{2}`.
    dOmega : float
        Angular frequency resolution of the wavelets (radians per unit
        time), or total wavelet angular frequency bandwidth,
        :math:`\Delta \Omega = 2A+B`.
    T : float
        Total duration of the time series. Related to :math:`N` and
        :math:`\delta t` by :math:`T = N \delta t`.
    df : float
        The frequency resolution of the time series, :math:`\delta f = 1/T`.
    f_s : float
        Sampling frequency of the time series,
        :math:`f_s = \frac{1}{\delta t}`.
    f_Ny : float
        Nyquist frequency (i.e. maximum frequency) of the time series,
        :math:`f_{\rm Ny} = \frac{1}{2 \delta t}`.
    K : int
        Window length in samples, :math:`K = 2 q N_f`.
        By definition, this is always an even integer.
    times : jnp.ndarray
        The sample times of the time series, :math:`t_k = k \delta t` for
        :math:`k\in\{0,1,\ldots,N-1\}`. Array shape=(N,).
    freqs : jnp.ndarray
        The sample frequencies of the time series, :math:`f_k = k \delta f`
        for :math:`k\in\{-N/2,-N/2+1,\ldots,N/2-1\}`. Array shape=(N,).
        Note, the zero-frequency component is in the center of the spectrum.
    Cnm : jnp.ndarray
        Coefficients :math:`C_{nm}` used for the wavelet transform.
        Equal to 1 if :math:`n+m` is even or :math:`i` if it's odd.
        Array shape=(N_t, N_f).
    calc_m0 : bool
        If this is set to False (default value) then the wavelet
        coefficients with :math:`m=0` are handled INCORRECTLY. This is
        faster. If these coefficients are needed then initialise the class
        with `calc_m0=True`.
    window_TD : jnp.ndarray
        The time-domain Meyer window function, :math:`\phi(t)`.
        Array shape=(N,).
    window_FD : jnp.ndarray
        The frequency-domain Meyer window function,
        :math:`\tilde{\Phi}(f)`. Array shape=(N,).
    _cached_Gnm_basis : jnp.ndarray or None
        Cache for the frequency-domain wavelet basis
        :math:`\tilde{G}_{nm}(f)`, filled by `Gnm_basis`.
        Array shape=(N, Nt, Nf).
    _cached_gnm_basis : jnp.ndarray or None
        Cache for the time-domain wavelet basis :math:`g_{nm}(t)`, filled by
        `gnm_basis`. Array shape=(N, Nt, Nf).
    jax_dtype : jnp.float64 or jnp.float32
        Floating point dtype; float64 if `jax_enable_x64` is set when the
        instance is created, otherwise float32.
        Use jax.config.update("jax_enable_x64", True).
    jax_dtype_int : jnp.int64 or jnp.int32
        Integer dtype matching `jax_dtype`.
        Use jax.config.update("jax_enable_x64", True).
    """

    def __init__(self,
                 dt: float,
                 Nf: int,
                 N: int,
                 q: int = 16,
                 d: int = 4,
                 A_frac: float = 0.25,
                 calc_m0: bool = False) -> None:
        r"""
        Initialize the WDM_transform.

        Parameters
        ----------
        dt : float
            The time series cadence, or time step.
        Nf : int
            Number of frequency bands, controls the time/frequency resolution.
        N : int
            Length of the input time series. Must be an even multiple of Nf.
        q : int
            Truncation parameter. Integer :math:`1 \leq q \leq N_t/2`.
            Optional.
        d : int
            Steepness parameter for the transition. Optional.
        A_frac : float
            Bandwidth fraction of flat-top response. Optional.
        calc_m0 : bool
            If False, then the wavelet calculations for the :math:`m=0` terms
            will be wrong; this has performance benefits. If True, then all
            calculations will be correct, but this may be slower. Optional.

        Raises
        ------
        AssertionError
            If any parameter fails `validate_parameters`.

        Returns
        -------
        None
        """
        self.dt = float(dt)
        self.Nf = int(Nf)
        self.N = int(N)
        self.q = int(q)
        self.A_frac = float(A_frac)
        self.d = int(d)
        self.calc_m0 = bool(calc_m0)

        self.validate_parameters()

        # Derived parameters
        self.times = jnp.arange(self.N) * self.dt
        self.freqs = jnp.fft.fftshift(jnp.fft.fftfreq(self.N, d=self.dt))
        self.Nt = self.N // self.Nf
        self.T = self.N * self.dt
        self.df = 1. / self.T
        self.dF = 1. / (2. * self.dt * self.Nf)
        self.dOmega = 2. * jnp.pi * self.dF
        self.dT = self.dt * self.Nf
        self.f_s = 1. / self.dt
        self.f_Ny = 0.5 / self.dt
        self.B_frac = 1. - 2. * self.A_frac
        self.A = self.A_frac * self.dOmega
        self.B = self.B_frac * self.dOmega
        self.K = 2 * self.q * self.Nf
        self.Cnm = jnp.array([[C_nm(n, m) for m in range(self.Nf)]
                              for n in range(self.Nt)])

        # The time-domain window is derived from the frequency-domain one,
        # so the order of these two calls matters.
        self.window_FD = self.build_frequency_domain_window()
        self.window_TD = self.build_time_domain_window()

        self._cached_Gnm_basis = None
        self._cached_gnm_basis = None

        if jax.config.read("jax_enable_x64"):
            self.jax_dtype = jnp.float64
            self.jax_dtype_int = jnp.int64
        else:
            self.jax_dtype = jnp.float32
            self.jax_dtype_int = jnp.int32

    def validate_parameters(self) -> None:
        r"""
        Validate the parameters provided to the WDM_transform __init__ method.

        Raises
        ------
        AssertionError
            If any parameters are invalid.

        Returns
        -------
        None
        """
        assert self.dt > 0, \
            f"dt must be positive, got {self.dt=}."

        assert self.Nf > 0 and self.Nf % 2 == 0, \
            f"Nf must be a positive even integer, got {self.Nf=}."

        assert self.N > 0 and self.N % 2 == 0, \
            f"N must be a positive even integer, got {self.N=}."

        assert self.N % self.Nf == 0 and (self.N // self.Nf) % 2 == 0, \
            f"N must be even multiple of Nf, got {self.N=}, {self.Nf=}."

        Nt = self.N // self.Nf

        assert self.q >= 1 and self.q <= Nt//2, \
            f"q must be integer in range 1<=q<={Nt//2}, got {self.q=}."

        # NOTE(review): B_frac = 1 - 2*A_frac is non-positive for
        # A_frac >= 0.5; presumably the upper bound should be 0.5 - confirm
        # against the Meyer window implementation before tightening.
        assert 0. < self.A_frac < 1., \
            f"A_frac must be in range 0<A_frac<1, got {self.A_frac=}."

        assert self.d >= 1, \
            f"d must be a positive integer, got {self.d=}."

    def build_frequency_domain_window(self) -> jnp.ndarray:
        r"""
        Construct the frequency-domain window function
        :math:`\tilde{\Phi}(f)`.

        The Meyer window is evaluated at the angular frequencies
        :math:`2\pi f_k` of `self.freqs` and scaled by :math:`\sqrt{2\pi}`.
        Note, the zero-frequency component is in the center of the spectrum.

        Returns
        -------
        Phi : jnp.ndarray
            Array of shape (N,). Frequency-domain window.
        """
        Phi = Meyer(2.*jnp.pi*self.freqs, self.d, self.A, self.B)
        return jnp.sqrt(2.*jnp.pi) * Phi

    def build_time_domain_window(self) -> jnp.ndarray:
        r"""
        Construct the time-domain window function :math:`\phi(t)`.

        This method takes the frequency-domain window `self.window_FD` and
        applies an inverse FFT to obtain the corresponding time-domain
        window. `self.window_FD` must therefore already be set.

        Returns
        -------
        phi : jnp.ndarray
            Array of shape (N,). Real-valued time-domain window.
        """
        # window_FD is stored zero-centred; undo that before the inverse FFT.
        phi = jnp.fft.ifft(jnp.fft.ifftshift(self.window_FD)).real / self.dt
        return phi

    @partial(jax.jit, static_argnums=0)
    def check_indices(self, n: jnp.ndarray, m: jnp.ndarray) -> bool:
        r"""
        Check if the wavelet indices are in the valid range.

        The `n` indices must satisfy :math:`0 \leq n < N_t` and the `m`
        indices must satisfy :math:`0 \leq m < N_f`.

        Parameters
        ----------
        n : jnp.ndarray
            Array of n indices, dtype=int. Wavelet time index.
        m : jnp.ndarray
            Array of m indices, dtype=int. Wavelet frequency index.

        Returns
        -------
        check : bool
            True if all the indices are valid, otherwise False.
        """
        n = jnp.asarray(n, self.jax_dtype_int)
        m = jnp.asarray(m, self.jax_dtype_int)
        n_test = jnp.all(jnp.logical_and(n >= 0, n < self.Nt))
        m_test = jnp.all(jnp.logical_and(m >= 0, m < self.Nf))
        check = jnp.logical_and(n_test, m_test)
        return check

    def wavelet_central_time_frequency(self,
                                       n: jnp.ndarray,
                                       m: jnp.ndarray
                                       ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        r"""
        Compute the central time :math:`t_{nm}= n \Delta T` and the central
        frequency :math:`f_{nm} = m \Delta F` of the wavelet
        :math:`g_{nm}(t)`.

        The case :math:`m=0` is special and is handled separately using

        .. math::

            t_{n0} = 2n \Delta T ,

        .. math::

            f_{n0} = \begin{cases}
                        0 & \mathrm{if}\; n<N_t/2 \\
                        f_{\mathrm{Ny}} & \mathrm{if}\; n\geq N_t/2
                     \end{cases} .

        Parameters
        ----------
        n : jnp.ndarray
            Wavelet time index, dtype=int, shape=(num_n,).
        m : jnp.ndarray
            Wavelet frequency index, dtype=int, shape=(num_m,).

        Returns
        -------
        t_nm : jnp.ndarray
            Array of times, shape=(num_n, num_m). The wavelet central times.
        f_nm : jnp.ndarray
            Array of frequencies, shape=(num_n, num_m).
            The wavelet central frequencies.

        Raises
        ------
        AssertionError
            If any index is out of range (see `check_indices`).
        """
        assert self.check_indices(n, m), f"Invalid indices: {n=} {m=}"
        return self.wavelet_central_time_frequency_compiled(n, m)

    @partial(jax.jit, static_argnums=0)
    def wavelet_central_time_frequency_compiled(
            self,
            n: jnp.ndarray,
            m: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Compiled part of wavelet_central_time_frequency method.

        No index validation is performed here.

        Parameters
        ----------
        n : jnp.ndarray
            Wavelet time index, dtype=int, shape=(num_n,).
        m : jnp.ndarray
            Wavelet frequency index, dtype=int, shape=(num_m,).

        Returns
        -------
        t_nm : jnp.ndarray
            Array of times, shape=(num_n, num_m). The wavelet central times.
        f_nm : jnp.ndarray
            Array of frequencies, shape=(num_n, num_m).
            The wavelet central frequencies.
        """
        n = jnp.asarray(n, self.jax_dtype_int)
        m = jnp.asarray(m, self.jax_dtype_int)

        n_col = n[:, None]  # (len(n), 1)
        m_row = m[None, :]  # (1, len(m))

        mzero = (m_row == 0)

        t_nm = jnp.where(mzero, 2 * n_col * self.dT, n_col * self.dT)

        f_m0 = jnp.where(n_col < (self.Nt // 2), 0.0, self.f_Ny)
        f_nm = jnp.where(mzero, f_m0, m_row * self.dF)

        return t_nm, f_nm

    def Gnm(self,
            n: int,
            m: int,
            freq: jnp.ndarray = None) -> jnp.ndarray:
        r"""
        Compute the frequency-domain representation of the wavelets,
        :math:`\tilde{G}_{nm}(f)`.

        This method computes the frequency-domain wavelet for a single choice
        of :math:`n` and :math:`m` using the expressions below. If you instead
        want to compute the full wavelet basis for all :math:`n` and :math:`m`
        efficiently, use the `Gnm_basis` method.

        For :math:`m>0`, the wavelet is given by

        .. math::

            \tilde{G}_{nm}(f) = \frac{\exp(-2\pi i n f \Delta T)}{\sqrt{2}}
                \left( C_{nm}\tilde{\Phi}(f+m\Delta F) +
                       C^*_{nm}\tilde{\Phi}(f-m\Delta F) \right) .

        For the special case :math:`m=0`, the wavelet is given by

        .. math::

            \tilde{G}_{n0}(f) = \begin{cases}
                \exp(-4\pi i n f \Delta T) \tilde{\Phi}(f)
                    & \mathrm{if}\; n<N_t/2 \\
                \frac{1}{2} \exp(-4\pi i n f \Delta T) \left(
                    \tilde{\Phi}(f-f_{\rm Ny}) +
                    \tilde{\Phi}(f+f_{\rm Ny}) \right)
                    & \mathrm{if}\; n\geq N_t/2
            \end{cases}

        Parameters
        ----------
        n : int
            Wavelet time index.
        m : int
            Wavelet frequency index.
        freq : jnp.ndarray
            Frequencies at which to evaluate the wavelet. If None, then
            defaults to self.freqs and the shifted windows are obtained by
            circularly shifting the precomputed `self.window_FD`. Otherwise
            the Meyer window is evaluated directly at the shifted
            frequencies (no periodic wrapping in frequency). Optional.

        Returns
        -------
        Gnm : jnp.ndarray
            Complex array shaped like freq. The frequency-domain wavelet.

        Raises
        ------
        AssertionError
            If the indices are out of range (see `check_indices`).
        """
        assert self.check_indices(n, m), f"Invalid indices: {n=} {m=}"

        if freq is None:
            f = self.freqs
            k_vals = jnp.arange(self.N)

            def Phi(shift: int) -> jnp.ndarray:
                # Phi(f + shift*dF) on the sample grid: a frequency offset of
                # dF corresponds to Nt/2 frequency bins.
                return self.window_FD[(k_vals + shift*self.Nt//2) % self.N]
        else:
            f = jnp.asarray(freq)

            def Phi(shift: int) -> jnp.ndarray:
                # Same construction as build_frequency_domain_window, but at
                # the requested (shifted) frequencies.
                omega = 2.*jnp.pi*(f + shift*self.dF)
                return jnp.sqrt(2.*jnp.pi) * \
                    Meyer(omega, self.d, self.A, self.B)

        if m > 0:
            Gnm = (1./jnp.sqrt(2.)) * \
                jnp.exp(-1j*n*2.*jnp.pi*f*self.dT) * (
                    C_nm(n, m) * Phi(m) +
                    jnp.conj(C_nm(n, m)) * Phi(-m)
                )
        elif n < self.Nt // 2:
            # zero-frequency terms
            Gnm = jnp.exp(-1j*n*4.*jnp.pi*f*self.dT) * Phi(0)
        else:
            # Nyquist-frequency terms; f_Ny = Nf * dF
            Gnm = 0.5 * jnp.exp(-1j*n*4.*jnp.pi*f*self.dT) * \
                (Phi(self.Nf) + Phi(-self.Nf))

        return Gnm

    def Gnm_basis(self) -> jnp.ndarray:
        r"""
        Efficient computation of frequency-domain wavelet basis
        :math:`\tilde{G}_{nm}(f)`.

        Instead of calling the functions for :math:`\tilde{G}_{nm}(f)`
        explicitly as is done in the `Gnm` method, this function shifts
        indices of `window_FD`.

        The result is cached to speed up subsequent calls. Only concrete
        arrays are cached: if this method is called while JAX is tracing
        (e.g. from inside a jitted function) and nothing is cached yet, the
        traced value is returned without being stored, so that no tracer
        leaks out of the trace.

        Returns
        -------
        basis : jnp.ndarray
            Array of shape (N, Nt, Nf). The frequency-domain wavelet basis.
            The first axis is frequency, the second is the wavelet time index,
            and the third is the wavelet frequency index.
        """
        if self._cached_Gnm_basis is not None:
            return self._cached_Gnm_basis

        basis = self._compute_Gnm_basis()
        if not isinstance(basis, jax.core.Tracer):
            self._cached_Gnm_basis = basis
        return basis

    @partial(jax.jit, static_argnums=0)
    def _compute_Gnm_basis(self) -> jnp.ndarray:
        r"""
        Jitted, side-effect free computation behind `Gnm_basis`.

        Returns
        -------
        basis : jnp.ndarray
            Array of shape (N, Nt, Nf).
        """
        n_vals = jnp.arange(self.Nt)
        m_vals = jnp.arange(self.Nf)
        k_vals = jnp.arange(self.N)

        om = 2. * jnp.pi * self.freqs

        # A frequency offset of m*dF corresponds to m*Nt/2 frequency bins.
        shift_up = (k_vals[:, jnp.newaxis] +
                    m_vals[jnp.newaxis, :]*self.Nt//2)
        shift_do = (k_vals[:, jnp.newaxis] -
                    m_vals[jnp.newaxis, :]*self.Nt//2)

        basis = (1./jnp.sqrt(2.)) * \
            jnp.exp(-1j*n_vals[jnp.newaxis, :, jnp.newaxis] *
                    om[:, jnp.newaxis, jnp.newaxis]*self.dT) * \
            (self.Cnm[jnp.newaxis, :, :] *
             self.window_FD[shift_up % self.N][:, jnp.newaxis, :] +
             jnp.conj(self.Cnm[jnp.newaxis, :, :]) *
             self.window_FD[shift_do % self.N][:, jnp.newaxis, :])

        if self.calc_m0:
            # overwrite m=0 terms for n<Nt/2 (zero-frequency terms)
            n_vals = jnp.arange(self.Nt//2)
            f0_term = jnp.exp(-2j*n_vals[jnp.newaxis, :] *
                              om[:, jnp.newaxis]*self.dT) * \
                self.window_FD[:, jnp.newaxis]
            basis = basis.at[:, n_vals, 0].set(f0_term)

            # overwrite m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            n_vals = jnp.arange(self.Nt//2, self.Nt)
            shift_up = (k_vals + self.N//2)
            shift_do = (k_vals - self.N//2)
            fNy_term = 0.5 * jnp.exp(-2j*n_vals[jnp.newaxis, :] *
                                     om[:, jnp.newaxis]*self.dT) * \
                (self.window_FD[shift_up % self.N][:, jnp.newaxis] +
                 self.window_FD[shift_do % self.N][:, jnp.newaxis])
            basis = basis.at[:, n_vals, 0].set(fNy_term)

        return basis

    def gnm(self, n: int, m: int) -> jnp.ndarray:
        r"""
        Compute the time-domain representation of the wavelets,
        :math:`g_{nm}(t)`.

        This method computes the frequency-domain wavelets for a single choice
        of :math:`n` and :math:`m` and performs an inverse Fourier transform.
        If you instead want to compute the full wavelet basis for all
        :math:`n` and :math:`m` efficiently, use the `gnm_basis` method.

        Parameters
        ----------
        n : int
            Wavelet time index.
        m : int
            Wavelet frequency index.

        Returns
        -------
        gnm : jnp.ndarray
            Array shape (N,). The time-domain wavelet.

        Raises
        ------
        AssertionError
            If the indices are out of range (see `check_indices`).
        """
        assert self.check_indices(n, m), f"Invalid indices: {n=} {m=}"

        Gnm = self.Gnm(n, m)
        gnm = jnp.fft.ifft(jnp.fft.ifftshift(Gnm)).real / self.dt

        return gnm

    def gnm_basis(self) -> jnp.ndarray:
        r"""
        Efficient computation of time-domain wavelet basis :math:`g_{nm}(t)`.

        Instead of calling the functions for :math:`\tilde{G}_{nm}(f)` and
        performing an inverse Fourier transform, as is done in the `gnm`
        method, this function shifts indices of `window_TD`.

        For :math:`m>0`, the wavelet is given by

        .. math::

            g_{nm}(t) = \begin{cases}
                \sqrt{2} (-1)^{mn} \cos\left(\frac{\pi m t}{\Delta T}\right)
                    \phi(t-n\Delta T) & \mathrm{if}\;n+m\;\mathrm{even} \\
                \sqrt{2} \sin\left(\frac{\pi m t}{\Delta T}\right)
                    \phi(t-n\Delta T) & \mathrm{if}\;n+m\;\mathrm{odd}
            \end{cases} .

        For the special case :math:`m=0` (only computed if `calc_m0` is
        True), the wavelet at sample :math:`t_k = k\delta t` is given by

        .. math::

            g_{n0}(t_k) = \begin{cases}
                \phi(t_k-2n\Delta T) & \mathrm{if}\;n<N_t/2 \\
                (-1)^k \phi(t_k-2n\Delta T) & \mathrm{if}\; n\geq N_t/2
            \end{cases}.

        The result is cached to speed up subsequent calls. Only concrete
        arrays are cached: if this method is called while JAX is tracing
        (e.g. from inside a jitted function) and nothing is cached yet, the
        traced value is returned without being stored, so that no tracer
        leaks out of the trace.

        Returns
        -------
        basis : jnp.ndarray
            Array of shape (N, Nt, Nf). The time-domain wavelet basis.
        """
        if self._cached_gnm_basis is not None:
            return self._cached_gnm_basis

        basis = self._compute_gnm_basis()
        if not isinstance(basis, jax.core.Tracer):
            self._cached_gnm_basis = basis
        return basis

    @partial(jax.jit, static_argnums=0)
    def _compute_gnm_basis(self) -> jnp.ndarray:
        r"""
        Jitted, side-effect free computation behind `gnm_basis`.

        Returns
        -------
        basis : jnp.ndarray
            Array of shape (N, Nt, Nf).
        """
        n_vals = jnp.arange(self.Nt)
        m_vals = jnp.arange(self.Nf)
        k_vals = jnp.arange(self.N)

        def wavelet_nm(n, m):
            # A phase shift of pi/2 turns the cosine into a sine when n+m is
            # odd.
            shift = ((n+m) % 2) * jnp.pi/2.
            return jnp.sqrt(2.) * (-1)**(n*m) * \
                jnp.cos(jnp.pi*m*k_vals/self.Nf-shift) * \
                self.window_TD[(k_vals-n*self.Nf) % self.N]

        f_vmapped = jax.vmap(jax.vmap(wavelet_nm, in_axes=(None, 0)),
                             in_axes=(0, None))

        basis = f_vmapped(n_vals, m_vals)           # (Nt, Nf, N)
        basis = jnp.transpose(basis, (2, 0, 1))     # (N, Nt, Nf)

        if self.calc_m0:
            # overwrite m=0 terms for n<Nt/2 (zero-frequency terms)
            n_vals = jnp.arange(self.Nt//2)
            f0_term = self.window_TD[(k_vals[:, jnp.newaxis]
                                      - 2*n_vals[jnp.newaxis, :]*self.Nf)
                                     % self.N]
            basis = basis.at[:, n_vals, 0].set(f0_term)

            # overwrite m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            n_vals = jnp.arange(self.Nt//2, self.Nt)

            def wavelet_Ny(n):
                return (-1)**(k_vals) * \
                    self.window_TD[(k_vals-2*n*self.Nf) % self.N]

            fNy_term = jax.vmap(wavelet_Ny)(n_vals).T
            basis = basis.at[:, n_vals, 0].set(fNy_term)

        return basis

    @partial(jax.jit, static_argnums=0)
    def short_fft(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        The windowed short FFT of the input.

        The input time series is split into :math:`N_t` overlapping segments
        each of length :math:`K` and with a hop interval of :math:`N_f`
        between their centres. Each of these segments is then windowed and
        FFT'd.

        .. math::

            X_n[j] = \sum_{k=-K/2}^{K/2-1} \exp(2\pi i kj/K)
                        x[nN_f+k] \phi[k]

        Parameters
        ----------
        x : jnp.ndarray
            Array shape (N,). Input time series signal to be transformed.

        Returns
        -------
        windowed_fft : jnp.ndarray
            Array shape (Nt, K). Short FFT of the input, :math:`X_n[j]`.

        Raises
        ------
        AssertionError
            If `x` does not have shape (N,).
        """
        x = jnp.asarray(x)

        assert x.shape == (self.N,), \
            f"Input signal must have shape ({self.N},), got {x.shape=}"

        windowed_fft = overlapping_windows(x, self.K, self.Nt, self.Nf)

        k_vals = jnp.arange(-self.K//2, self.K//2)

        # The sum over k starts at -K/2 rather than 0; this offset shows up
        # as an alternating sign on the FFT output.
        sign = (-1)**jnp.arange(self.K)

        windowed_fft *= self.window_TD[k_vals % self.N]

        # ifft * K gives the unnormalised sum with a positive exponent.
        windowed_fft = jnp.fft.ifft(windowed_fft, axis=-1) * self.K * sign

        return windowed_fft

    def _m0_coefficients(self,
                         x: jnp.ndarray
                         ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        r"""
        Compute the :math:`m=0` wavelet coefficients from the truncated
        window.

        .. math::

            w_{n0} = \delta t \begin{cases}
                \sum_{k=-K/2}^{K/2-1} x[k+2nN_f]\phi[k]
                    & \mathrm{if}\;n<N_t/2 \\
                \sum_{k=-K/2}^{K/2-1} (-1)^k x[k+2nN_f]\phi[k]
                    & \mathrm{if}\;n\geq N_t/2
            \end{cases}

        Parameters
        ----------
        x : jnp.ndarray
            The time-domain signal. Array shape (..., N).

        Returns
        -------
        f0_term : jnp.ndarray
            Zero-frequency coefficients for :math:`n=0,\ldots,N_t/2-1`.
            Array shape (..., Nt/2).
        fNy_term : jnp.ndarray
            Nyquist-frequency coefficients for
            :math:`n=N_t/2,\ldots,N_t-1`. Array shape (..., Nt/2).
        """
        k_vals = jnp.arange(-self.K//2, self.K//2)
        n_vals = jnp.arange(self.Nt//2)

        # Sample indices k + 2*n*Nf. Since 2*(n + Nt/2)*Nf = 2*n*Nf + N, the
        # same (circularly wrapped) indices serve both halves of the n range.
        k_plus_2n = (k_vals[:, jnp.newaxis] +
                     2*n_vals[jnp.newaxis, :]*self.Nf)

        phi = self.window_TD[k_vals % self.N, jnp.newaxis]
        x_seg = jnp.take(x, k_plus_2n, axis=-1, mode='wrap')

        f0_term = self.dt * jnp.sum(phi * x_seg, axis=-2)

        fNy_term = self.dt * jnp.sum(
            (-1)**k_vals[:, jnp.newaxis] * phi * x_seg, axis=-2)

        return f0_term, fNy_term

    @partial(jax.jit, static_argnums=0)
    def forward_transform_exact(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the forward discrete wavelet transform. Transforms the input
        signal from the time domain into the time-frequency domain.

        This method computes the wavelet coefficients using the exact
        expression

        .. math::

            w_{nm} = \delta t \sum_{k=0}^{N-1} g_{nm}[k] x[k] ,

        where the sum is over the whole time-domain signal (no truncation).

        This method is slow but exact.

        Parameters
        ----------
        x : jnp.ndarray
            Array shape (N,). Input time-domain signal to be transformed.

        Returns
        -------
        w : jnp.ndarray
            Array shape (Nt, Nf). WDM time-frequency-domain wavelet
            coefficients.

        Raises
        ------
        AssertionError
            If `x` does not have shape (N,).
        """
        x = jnp.asarray(x)

        assert x.shape == (self.N,), \
            f"Input signal must have shape ({self.N},), got {x.shape=}"

        # Move the sample axis last so that it broadcasts against x.
        gnm_basis = jnp.transpose(self.gnm_basis(), (1, 2, 0))

        w = jnp.sum(gnm_basis * x, axis=-1) * self.dt

        return w

    @partial(jax.jit, static_argnums=0)
    def forward_transform_truncated(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the forward discrete wavelet transform. Transforms the input
        signal from the time domain into the time-frequency domain.

        This method computes the wavelet coefficients using the truncated
        expressions

        .. math::

            w_{n0} = \delta t\sum_{k=-K/2}^{K/2-1}
                        g_{nm}[k + 2 n N_f] x[k + 2 n N_f] ,

        .. math::

            w_{nm} = \delta t\sum_{k=-K/2}^{K/2-1}
                        g_{nm}[k + n N_f] x[k + n N_f]
                        \quad \mathrm{for} \; m>0 ,

        where the sum is over the truncated window of length
        :math:`K=2qN_f`. In the above expressions, indices out of bounds of
        the array are to be understood as wrapping around circularly.

        If `calc_m0` is False the :math:`m=0` column is left as zeros.

        Parameters
        ----------
        x : jnp.ndarray
            Array shape (N,). Input time-domain signal to be transformed.

        Returns
        -------
        w : jnp.ndarray
            Array shape (Nt, Nf). WDM time-frequency-domain wavelet
            coefficients.

        Raises
        ------
        AssertionError
            If `x` does not have shape (N,).

        Notes
        -----
        This method is slow. It is only intended to be used for testing and
        debugging purposes.
        """
        x = jnp.asarray(x)

        assert x.shape == (self.N,), \
            f"Input signal must have shape ({self.N},), got {x.shape=}"

        w = jnp.zeros((self.Nt, self.Nf), dtype=self.jax_dtype)

        B = self.gnm_basis()

        k_vals = jnp.arange(-self.K//2, self.K//2)

        # start at m=0 if the m=0 terms are requested, otherwise at m=1
        m_start = int(not self.calc_m0)

        for n in range(self.Nt):
            for m in range(m_start, self.Nf):
                # m=0 wavelets are spaced by 2*Nf samples, the rest by Nf.
                stride = 1 if m > 0 else 2
                idx = (k_vals + stride*n*self.Nf) % self.N
                gnm = B[:, n, m]
                w = w.at[n, m].set(self.dt*jnp.sum(gnm[idx] * x[idx]))

        return w

    @partial(jax.jit, static_argnums=0)
    def forward_transform_truncated_window(self,
                                           x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the forward discrete wavelet transform. Transforms the input
        signal from the time domain into the time-frequency domain.

        This method computes the wavelet coefficients using the truncated
        expressions using the window function:

        .. math::

            w_{n0} = \delta t \begin{cases}
                \sum_{k=-K/2}^{K/2-1} x[k+2nN_f]\phi[k]
                    & \mathrm{if}\;n<N_t/2 \\
                \sum_{k=-K/2}^{K/2-1} (-1)^k x[k+2nN_f]\phi[k]
                    & \mathrm{if}\;n\geq N_t/2 \\
            \end{cases} ,

        .. math::

            w_{nm} = \sqrt{2}\delta t \, \mathrm{Re} \sum_{k=-K/2}^{K/2-1}
                        C^*_{nm} \exp\left(\frac{i\pi km}{N_f}\right)
                        x[k+nN_f] \phi[k] \quad \mathrm{for}\; m>0.

        The :math:`m=0` expression is only used if `calc_m0` is True.

        Parameters
        ----------
        x : jnp.ndarray
            Array shape (N,). Input time-domain signal to be transformed.

        Returns
        -------
        w : jnp.ndarray
            Array shape (Nt, Nf). WDM time-frequency-domain wavelet
            coefficients.

        Raises
        ------
        AssertionError
            If `x` does not have shape (N,).

        Notes
        -----
        This method is slow. It is only intended to be used for testing and
        debugging purposes.
        """
        x = jnp.asarray(x)

        assert x.shape == (self.N,), \
            f"Input signal must have shape ({self.N},), got {x.shape=}"

        n_vals = jnp.arange(self.Nt)
        m_vals = jnp.arange(self.Nf)
        k_vals = jnp.arange(-self.K//2, self.K//2)

        k_plus_n = (k_vals[:, jnp.newaxis] +
                    n_vals[jnp.newaxis, :]*self.Nf) % self.N

        mk = m_vals[jnp.newaxis, jnp.newaxis, :] * \
            k_vals[:, jnp.newaxis, jnp.newaxis]

        # Axes of the summand are (k, n, m); the sum runs over k.
        w = jnp.sqrt(2.) * self.dt * \
            jnp.sum(
                jnp.conj(self.Cnm[jnp.newaxis, :, :]) *
                jnp.exp((1j)*jnp.pi*mk/self.Nf) *
                x[k_plus_n][:, :, jnp.newaxis] *
                self.window_TD[k_vals % self.N, jnp.newaxis, jnp.newaxis],
                axis=0).real

        if self.calc_m0:
            f0_term, fNy_term = self._m0_coefficients(x)

            # overwrite m=0 terms for n<Nt/2 (zero-frequency terms)
            w = w.at[jnp.arange(self.Nt//2), 0].set(f0_term)

            # overwrite m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            w = w.at[jnp.arange(self.Nt//2, self.Nt), 0].set(fNy_term)

        return w

    @partial(jax.jit, static_argnums=0)
    def forward_transform_short_fft(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the forward discrete wavelet transform. Transforms the input
        signal from the time domain into the time-frequency domain.

        For the :math:`m>0` terms, the wavelet coefficients are calculated
        using the following expression,

        .. math::

            w_{nm} = \sqrt{2} \delta t \, \mathrm{Re}\, C_{nm}^* X_n[mq] ,

        where the short FFT is defined as

        .. math::

            X_n[j] = \sum_{k=-K/2}^{K/2-1} \exp(2\pi i kj/K)
                        x[nN_f+k] \phi[k].

        The :math:`m=0` terms, if required, are calculated using the same
        method as in `forward_transform_truncated_window`.

        Parameters
        ----------
        x : jnp.ndarray
            Array shape (N,). Input time-domain signal to be transformed.

        Returns
        -------
        w : jnp.ndarray
            Array shape (Nt, Nf). WDM time-frequency-domain wavelet
            coefficients.

        Raises
        ------
        AssertionError
            If `x` does not have shape (N,).

        Notes
        -----
        This method is fairly fast. But `forward_transform_fft` is usually
        faster. This is included for testing and debugging purposes.
        """
        x = jnp.asarray(x)

        assert x.shape == (self.N,), \
            f"Input signal must have shape ({self.N},), got {x.shape=}"

        X = self.short_fft(x)

        m_vals = jnp.arange(self.Nf)

        w = jnp.sqrt(2.) * self.dt * \
            jnp.real(jnp.conj(self.Cnm) * X[:, (m_vals*self.q) % self.K])

        if self.calc_m0:
            f0_term, fNy_term = self._m0_coefficients(x)

            # overwrite m=0 terms for n<Nt/2 (zero-frequency terms)
            w = w.at[jnp.arange(self.Nt//2), 0].set(f0_term)

            # overwrite m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            w = w.at[jnp.arange(self.Nt//2, self.Nt), 0].set(fNy_term)

        return w

    @partial(jax.jit, static_argnums=0)
    def forward_transform_fft(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the forward discrete wavelet transform. Transforms the input
        signal from the time domain into the time-frequency domain.

        For the :math:`m>0` terms, the wavelet coefficients are calculated
        using the following expression,

        .. math::

            w_{nm} = \frac{\sqrt{2}\delta t}{N} (-1)^{nm} \,\mathrm{Re}\,
                        \Big( C_{nm}^* x_m[n] \Big)

        where

        .. math::

            x_m[n] = \sum_{l=-N_t/2}^{N_t/2-1}
                        \exp\left(\frac{2\pi i nl}{N_t} \right)
                            \Phi[l] X[l-mN_t/2] .

        The :math:`m=0` terms, if required, are calculated using the same
        method as in `forward_transform_truncated_window`.

        This is vectorised to allow for batch jobs computing the dwt for
        multiple time series at once; note the shapes of the input and output
        arrays.

        Parameters
        ----------
        x : jnp.ndarray
            The time-domain signal. Array shape (..., N).

        Returns
        -------
        w : jnp.ndarray
            Wavelet coefficients. Array shape (..., Nt, Nf).

        Raises
        ------
        AssertionError
            If the last axis of `x` does not have length N.

        Notes
        -----
        This method is fast. Use this to perform discrete wavelet transforms
        for production analysis. This method is called by `self.dwt`.
        """
        x = jnp.asarray(x, dtype=self.jax_dtype)

        assert x.shape[-1:] == (self.N,), \
            f"Input signal must have shape (..., {self.N}), " \
            f"got {x.shape=}."

        leading = x.shape[:-1]

        l_vals = jnp.arange(-self.Nt//2, self.Nt//2)
        n_vals = jnp.arange(self.Nt)
        m_vals = jnp.arange(self.Nf)

        # Frequency-bin index l - m*Nt/2 for every (l, m); shape (Nt, Nf).
        mask = l_vals[:, jnp.newaxis] - \
            m_vals[jnp.newaxis, :]*self.Nt//2

        X = jnp.fft.fft(x, axis=-1) * self.dt
        X = jnp.take(X, mask, axis=-1, mode='wrap')     # (..., Nt, Nf)

        # Window at bins l, broadcastable against X: (1, ..., 1, Nt, 1).
        # Tuple indexing is equivalent to the star-unpacked subscript but
        # also parses on Python < 3.11.
        Phi_index = (jnp.newaxis,)*len(leading) + (l_vals, jnp.newaxis)
        Phi = jnp.fft.ifftshift(self.window_FD)[Phi_index]

        x_mn = self.Nt * jnp.fft.ifft(Phi*X, axis=-2)

        # The ifft sums over array positions 0..Nt-1 whereas l starts at
        # -Nt/2; the trailing (-1)**n factor accounts for this offset.
        w = jnp.sqrt(2.) * self.df * \
            (-1)**(n_vals[:, jnp.newaxis] * m_vals[jnp.newaxis, :]) * \
            jnp.real(jnp.conj(self.Cnm[:, :]) * x_mn) * \
            (-1)**(n_vals[:, jnp.newaxis])

        if self.calc_m0:
            f0_term, fNy_term = self._m0_coefficients(x)

            # overwrite m=0 terms for n<Nt/2 (zero-frequency terms)
            w = w.at[..., jnp.arange(self.Nt//2), 0].set(f0_term)

            # overwrite m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            w = w.at[..., jnp.arange(self.Nt//2, self.Nt), 0].set(fNy_term)

        return w

    @partial(jax.jit, static_argnums=0)
    def inverse_transform(self, w: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the inverse discrete wavelet transform. Transforms the wavelet
        coefficients from the time-frequency domain into the time domain.

        This method computes the inverse dwt using the truncated wavelets.
        The :math:`m=0` coefficients only contribute if `calc_m0` is True.

        This is also vectorised to allow for batch jobs computing the idwt for
        multiple sets of wavelet coefficients at once; note the shapes of the
        input and output arrays.

        Parameters
        ----------
        w : jnp.ndarray
            Wavelet coefficients. Array shape (..., Nt, Nf).

        Returns
        -------
        x : jnp.ndarray
            The time-domain signal. Array shape (..., N).

        Raises
        ------
        AssertionError
            If the last two axes of `w` do not have shape (Nt, Nf).
        """
        w = jnp.asarray(w, dtype=self.jax_dtype)

        assert w.shape[-2:] == (self.Nt, self.Nf), \
            f"Input coefficients must have shape ({self.Nt}, {self.Nf}), " \
            f"got {w.shape[-2:]=}."

        leading = w.shape[:-2]

        x = jnp.zeros(leading+(self.N,), dtype=self.jax_dtype)

        @jax.jit
        def add_one_time(x, n):
            k_vals = jnp.arange(-self.K//2, self.K//2)
            indices = (k_vals+n*self.Nf) % self.N

            @jax.jit
            def add_one_freq(x, m):
                shift = ((n+m) % 2) * jnp.pi/2.
                wavelet = jnp.sqrt(2.) * (-1)**(n*m) * \
                    jnp.cos(jnp.pi*m*indices/self.Nf-shift) * \
                    self.window_TD[k_vals]
                coeff = jnp.atleast_1d(w[..., n, m])
                term = coeff[..., None] * wavelet[None, ...]
                updates_shape = x[..., indices].shape
                x = x.at[..., indices].add(jnp.reshape(term, updates_shape))
                return x

            x = jax.lax.fori_loop(1,  # only sum over m>0
                                  self.Nf,
                                  lambda m, acc: add_one_freq(acc, m),
                                  x)
            return x

        x = jax.lax.fori_loop(0,
                              self.Nt,
                              lambda n, acc: add_one_time(acc, n),
                              x)

        if self.calc_m0:
            k_vals = jnp.arange(-self.K//2, self.K//2)
            phi = self.window_TD[k_vals]

            def add_m0_term(x, n, wavelet):
                # m=0 wavelets are centred on sample 2*n*Nf.
                indices = (k_vals+2*n*self.Nf) % self.N
                coeff = jnp.atleast_1d(w[..., n, 0])
                term = coeff[..., None] * wavelet[None, ...]
                updates_shape = x[..., indices].shape
                return x.at[..., indices].add(
                    jnp.reshape(term, updates_shape))

            # add m=0 terms for n<Nt/2 (zero-frequency terms)
            x = jax.lax.fori_loop(
                0,
                self.Nt//2,
                lambda n, acc: add_m0_term(acc, n, phi),
                x)

            # add m=0 terms for n>=Nt/2 (Nyquist-frequency terms)
            phi_Ny = (-1)**(k_vals) * phi
            x = jax.lax.fori_loop(
                self.Nt//2,
                self.Nt,
                lambda n, acc: add_m0_term(acc, n, phi_Ny),
                x)

        return x

    def inverse_transform_exact(self, w: jnp.ndarray) -> jnp.ndarray:
        r"""
        Perform the inverse discrete wavelet transform. Transforms the wavelet
        coefficients from the time-frequency domain into the time domain.

        This method computes the inverse dwt directly using the expression

        .. math::

            x[k] = \sum_{n=0}^{N_t-1} \sum_{m=0}^{N_f-1} w_{nm} g_{nm}[k] .

        This method is slow and very memory inefficient. It is here mainly
        for testing. Consider using `inverse_transform` instead.

        Parameters
        ----------
        w : jnp.ndarray
            Array shape (Nt, Nf). WDM time-frequency-domain wavelet
            coefficients.

        Returns
        -------
        x : jnp.ndarray
            Array shape (N,). The time-domain signal.

        Raises
        ------
        AssertionError
            If `w` does not have shape (Nt, Nf).
        """
        w = jnp.asarray(w, dtype=self.jax_dtype)

        assert w.shape == (self.Nt, self.Nf), \
            f"Input coefficients must have shape ({self.Nt}, {self.Nf}), " \
            f"got {w.shape=}."

        gnm_basis = self.gnm_basis()

        # (Nt, Nf) broadcasts against (N, Nt, Nf); flatten (n, m) and sum.
        wg = w * gnm_basis
        wg = wg.reshape(wg.shape[0], -1)

        x = jnp.sum(wg, axis=-1)

        return x

    def dwt(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Forward discrete wavelet transform. Calls
        `self.forward_transform_fft`. Vectorised to allow for transforming
        multiple time series at once.

        Parameters
        ----------
        x : jnp.ndarray
            Input time series. Array shape=(N,) or (..., N).

        Returns
        -------
        w : jnp.ndarray
            Wavelet coefficients. Array shape=(Nt, Nf) or (..., Nt, Nf).

        Raises
        ------
        AssertionError
            If `x` has a non-zero imaginary part.
        """
        x = jnp.asarray(x)

        # Check before casting: converting to a real dtype first would make
        # the check vacuous.
        assert jnp.all(jnp.isreal(x)), "time series must be real."

        x = jnp.real(x).astype(self.jax_dtype)

        return self.forward_transform_fft(x)

    def idwt(self, w: jnp.ndarray) -> jnp.ndarray:
        r"""
        Inverse discrete wavelet transform. Calls `self.inverse_transform`.
        Vectorised to allow for transforming multiple time series at once.

        Parameters
        ----------
        w : jnp.ndarray
            Wavelet coefficients. Array shape=(Nt, Nf) or (..., Nt, Nf).

        Returns
        -------
        x : jnp.ndarray
            Time series. Array shape=(N,) or (..., N).

        Raises
        ------
        AssertionError
            If `w` has a non-zero imaginary part.
        """
        w = jnp.asarray(w)

        # Check before casting: converting to a real dtype first would make
        # the check vacuous.
        assert jnp.all(jnp.isreal(w)), "wavelet coefficients must be real."

        w = jnp.real(w).astype(self.jax_dtype)

        return self.inverse_transform(w)

    def __repr__(self) -> str:
        r"""
        String representation of the WDM_transform instance.

        Returns
        -------
        text : str
            A string representation of WDM_transform instance.
        """
        lines = [
            (f"WDM_transform(Nf={self.Nf}, N={self.N}, q={self.q}, "
             f"d={self.d}, A_frac={self.A_frac}, calc_m0={self.calc_m0})"),
            f"{self.Nt = } time cells",
            f"{self.Nf = } frequency cells",
            f"{self.dT = } time resolution",
            f"{self.dF = } frequency resolution",
            f"{self.K = } window length",
        ]
        text = "\n".join(lines)
        return text

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        r"""
        Calls the forward transform self.dwt.
        """
        return self.dwt(x)