Source code for WDM.code.utils.utils

from functools import partial
import jax
import jax.numpy as jnp


[docs] def next_multiple(i: int, N: int) -> int: r""" Return smallest integer multiple of N greater than or equal to integer i. Parameters ---------- i : int The input number. N : int The multiple to align to. Returns ------- j : int The next multiple of N. Notes ----- Example with N = 3: >>> for i in [-4, -3, -2, -1, 0, 1, 2, 3, 4]: ... print(f"{i} -> {next_multiple(i, 4)}") -4 -> -3 -3 -> -3 -2 -> 0 -1 -> 0 0 -> 0 1 -> 3 2 -> 3 3 -> 3 4 -> 6 """ j = ((i + N - 1) // N) * N return j
[docs] def C_nm(n: int, m: int) -> complex: r""" Compute the complex-valued modulation coefficient :math:`C_{nm}`. This coefficient alternates between 1 and :math:`i` to apply modulation in the WDM transform. Parameters ---------- n : int Time index. m : int Frequency index. Returns ------- complex Coefficient :math:`C_{nm}`, equal to 1 or :math:`i` depending on parity of :math:`n+m`. """ return 1.0 if (n + m) % 2 == 0 else 1.0j
[docs] @partial(jax.jit, static_argnums=(1, 2, 3)) def overlapping_windows(x: jnp.ndarray, K: int, Nt: int, Nf: int) -> jnp.ndarray: """ Extract overlapping, wrapped windows from input array `x`. Parameters ---------- x : jnp.ndarray, shape (N,) Input array to extract windows from. K : int Window length (must be even). Nt : int Number of windows (time steps). Nf : int Hop size between window centers. Returns ------- windows : jnp.ndarray, shape (Nt, K) Array of overlapping windows with wraparound indexing. Examples -------- >>> import jax.numpy as jnp >>> Nt = 4 >>> Nf = 4 >>> K = 8 >>> x = jnp.arange(Nt*Nf) >>> overlapping_windows(x, K, Nt, Nf) array([ [12, 13, 14, 15, 0, 1, 2, 3], [ 0, 1, 2, 3, 4, 5, 6, 7], [ 4, 5, 6, 7, 8, 9, 10, 11], [ 8, 9, 10, 11, 12, 13, 14, 15] ]) """ N = x.shape[0] # Centered window indices relative to each window center k_offsets = jnp.arange(-K//2, K//2) # Window center indices centers = jnp.arange(Nt) * Nf # Create full (Nt, K) index matrix with wraparound idx = (centers[:,jnp.newaxis] + k_offsets[jnp.newaxis,:]) % N return x[idx]
[docs] def pad_signal(x : jnp.ndarray, N : int, where: str = 'end') -> jnp.ndarray: r""" The transform method requires the input time series signal to have a specific length :math:`N`. This method can be used to zero-pad any signal to the desired length. This function also returns a Boolean mask that can be used later to recover arrays of the original length. Parameters ---------- x : jnp.ndarray Input signal to be padded. N : int Desired length of the output signal. where : str Where to add the padding. Options are 'end', 'start', or 'equal' which puts the zero padding at the end of the signal, the start of the signal, or equally at both ends respectively. Optional. Returns ------- x_padded : jnp.ndarray Padded signal to length N, with zeros added at the end. mask : jnp.ndarray Boolean mask indicating the valid part of the padded signal. Notes ----- The Boolean mask can be used to get back to the original signal; i.e. `x_padded[mask]` will recover the original signal, `x`. """ x = jnp.asarray(x) n = len(x) padding_length = N - n assert padding_length >= 0, \ f"Input signal length {n} exceeds desired length {N}." mask = jnp.full(N, True, dtype=bool) if where == 'end': x_padded = jnp.pad(x, (0, padding_length), mode='constant', constant_values=0) mask = mask.at[n:].set(False) elif where == 'start': x_padded = jnp.pad(x, (padding_length, 0), mode='constant', constant_values=0) mask = mask.at[:padding_length].set(False) elif where == 'equal': a = padding_length // 2 b = padding_length - a x_padded = jnp.pad(x, (a, b), mode='constant', constant_values=0) mask = mask.at[:a].set(False) mask = mask.at[n + a:].set(False) else: raise ValueError(f"Invalid padding location {where=}.") return x_padded, mask