|
| 1 | +"""PyTorch compatible cwt code. |
| 2 | +
|
| 3 | +Based on https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py |
| 4 | +""" |
| 5 | +from typing import Tuple, Union |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import torch |
| 9 | +from pywt import ContinuousWavelet, DiscreteContinuousWavelet, Wavelet |
| 10 | +from pywt._functions import integrate_wavelet, scale2frequency |
| 11 | +from torch.fft import fft, ifft |
| 12 | + |
| 13 | + |
| 14 | +def _next_fast_len(n: int) -> int: |
| 15 | + """Round up size to the nearest power of two. |
| 16 | +
|
| 17 | + Given a number of samples `n`, returns the next power of two |
| 18 | + following this number to take advantage of FFT speedup. |
| 19 | + This fallback is less efficient than `scipy.fftpack.next_fast_len` |
| 20 | + Taken from: |
| 21 | + https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py |
| 22 | + """ |
| 23 | + return int(2 ** np.ceil(np.log2(n))) |
| 24 | + |
| 25 | + |
| 26 | +def cwt( |
| 27 | + data: torch.Tensor, |
| 28 | + scales: Union[np.ndarray, torch.Tensor], # type: ignore |
| 29 | + wavelet: Union[ContinuousWavelet, str], |
| 30 | + sampling_period: float = 1.0, |
| 31 | +) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore |
| 32 | + """Compute the single dimensional continuous wavelet transform. |
| 33 | +
|
| 34 | + This function is a PyTorch port of pywt.cwt as found at: |
| 35 | + https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py |
| 36 | +
|
| 37 | + Args: |
| 38 | + data (torch.Tensor): The input tensor of shape [batch_size, time]. |
| 39 | + scales (torch.Tensor or np.array): |
| 40 | + The wavelet scales to use. One can use |
| 41 | + ``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to determine |
| 42 | + what physical frequency, ``f``. Here, ``f`` is in hertz when the |
| 43 | + ``sampling_period`` is given in seconds. |
| 44 | + wavelet (str or Wavelet of ContinuousWavelet): The wavelet to work with. |
| 45 | + wavelet (ContinuousWavelet or str): The continuous wavelet to work with. |
| 46 | + sampling_period (float): Sampling period for the frequencies output (optional). |
| 47 | + The values computed for ``coefs`` are independent of the choice of |
| 48 | + ``sampling_period`` (i.e. ``scales`` is not scaled by the sampling |
| 49 | + period). |
| 50 | +
|
| 51 | + Raises: |
| 52 | + ValueError: If a scale is too small for the input signal. |
| 53 | +
|
| 54 | + Returns: |
| 55 | + Tuple[torch.Tensor, np.ndarray]: A tuple with the transformation matrix |
| 56 | + and frequencies in this order. |
| 57 | + """ |
| 58 | + # accept array_like input; make a copy to ensure a contiguous array |
| 59 | + if not isinstance(wavelet, (ContinuousWavelet, Wavelet)): |
| 60 | + wavelet = DiscreteContinuousWavelet(wavelet) |
| 61 | + if type(scales) is torch.Tensor: |
| 62 | + scales = scales.numpy() |
| 63 | + elif np.isscalar(scales): |
| 64 | + scales = np.array([scales]) |
| 65 | + # if not np.isscalar(axis): |
| 66 | + # raise np.AxisError("axis must be a scalar.") |
| 67 | + |
| 68 | + precision = 10 |
| 69 | + int_psi, x = integrate_wavelet(wavelet, precision=precision) |
| 70 | + if type(wavelet) is ContinuousWavelet: |
| 71 | + int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi |
| 72 | + int_psi = torch.tensor(int_psi, device=data.device) |
| 73 | + |
| 74 | + # convert int_psi, x to the same precision as the data |
| 75 | + x = np.asarray(x, dtype=data.cpu().numpy().real.dtype) |
| 76 | + |
| 77 | + size_scale0 = -1 |
| 78 | + fft_data = None |
| 79 | + |
| 80 | + out = [] |
| 81 | + for scale in scales: |
| 82 | + step = x[1] - x[0] |
| 83 | + j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step) |
| 84 | + j = j.astype(int) # floor |
| 85 | + if j[-1] >= len(int_psi): |
| 86 | + j = np.extract(j < len(int_psi), j) |
| 87 | + int_psi_scale = int_psi[j].flip(0) |
| 88 | + |
| 89 | + # The padding is selected for: |
| 90 | + # - optimal FFT complexity |
| 91 | + # - to be larger than the two signals length to avoid circular |
| 92 | + # convolution |
| 93 | + size_scale = _next_fast_len(data.shape[-1] + len(int_psi_scale) - 1) |
| 94 | + if size_scale != size_scale0: |
| 95 | + # Must recompute fft_data when the padding size changes. |
| 96 | + fft_data = fft(data, size_scale, dim=-1) |
| 97 | + size_scale0 = size_scale |
| 98 | + fft_wav = fft(int_psi_scale, size_scale, dim=-1) |
| 99 | + conv = ifft(fft_wav * fft_data, dim=-1) |
| 100 | + conv = conv[..., : data.shape[-1] + len(int_psi_scale) - 1] |
| 101 | + |
| 102 | + coef = -np.sqrt(scale) * torch.diff(conv, dim=-1) |
| 103 | + |
| 104 | + # transform axis is always -1 |
| 105 | + d = (coef.shape[-1] - data.shape[-1]) / 2.0 |
| 106 | + if d > 0: |
| 107 | + coef = coef[..., int(np.floor(d)) : -int(np.ceil(d))] |
| 108 | + elif d < 0: |
| 109 | + raise ValueError("Selected scale of {} too small.".format(scale)) |
| 110 | + |
| 111 | + out.append(coef) |
| 112 | + out_tensor = torch.stack(out) |
| 113 | + if type(wavelet) is Wavelet: |
| 114 | + out_tensor = out_tensor.real |
| 115 | + else: |
| 116 | + out_tensor = out_tensor if wavelet.complex_cwt else out_tensor.real |
| 117 | + |
| 118 | + frequencies = scale2frequency(wavelet, scales, precision) |
| 119 | + if np.isscalar(frequencies): |
| 120 | + frequencies = np.array([frequencies]) |
| 121 | + frequencies /= sampling_period |
| 122 | + return out_tensor, frequencies |
0 commit comments