Skip to content

Commit de7aaa7

Browse files
authored
Merge pull request #27 from v0lta/0.1-release-candidate
version 0.1 - release candidate
2 parents e67a25d + 15a9538 commit de7aaa7

File tree

18 files changed

+424
-63
lines changed

18 files changed

+424
-63
lines changed

.bumpversion.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[bumpversion]
2-
current_version = 0.0.19-dev
2+
current_version = 0.1.0-dev
33
commit = True
44
tag = False
55
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?
96.8 KB
Loading
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
import numpy as np
3+
import src.ptwt as ptwt
4+
import matplotlib.pyplot as plt
5+
import scipy.signal as signal
6+
7+
if __name__ == "__main__":
8+
t = np.linspace(-2, 2, 800, endpoint=False)
9+
sig = signal.chirp(t, f0=1, f1=12, t1=2, method="linear")
10+
widths = np.arange(1, 31)
11+
cwtmatr_pt, freqs = ptwt.cwt(
12+
torch.from_numpy(sig), widths, "mexh", sampling_period=(4 / 800) * np.pi
13+
)
14+
cwtmatr = cwtmatr_pt.numpy()
15+
fig, axs = plt.subplots(2)
16+
axs[0].plot(t, sig)
17+
axs[0].set_ylabel("magnitude")
18+
axs[1].imshow(
19+
cwtmatr,
20+
cmap="PRGn",
21+
aspect="auto",
22+
vmax=abs(cwtmatr).max(),
23+
vmin=-abs(cwtmatr).max(),
24+
extent=[min(t), max(t), min(freqs), max(freqs)],
25+
)
26+
axs[1].set_xlabel("time")
27+
axs[1].set_ylabel("frequency")
28+
plt.show()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#### Cwt chirp analysis
2+
To run this example, clone the repository and type
3+
```ipython examples/wavelet_packet_chirp_analysis/cwt_chirp_analysis.py```
4+
5+
The result should look like this:
6+
7+
![alt text](chirp_cwt.png)
-11.1 KB
Loading

examples/wavelet_packet_chirp_analysis/chirp_analysis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
np_lst = []
1919
for node in nodes:
2020
np_lst.append(wp[node])
21-
viz = np.stack(np_lst)
21+
viz = np.stack(np_lst).squeeze()
2222

2323
fig, axs = plt.subplots(2)
2424
axs[0].plot(t, w)

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
##########################
44
[metadata]
55
name = ptwt
6-
version = 0.0.19-dev
6+
version = 0.1.0-dev
77
description = Differentiable and gpu enabled fast wavelet transforms in PyTorch
88
long_description = file: README.rst
99
long_description_content_type = text/x-rst

src/ptwt/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Differentiable and gpu enabled fast wavelet transforms in PyTorch."""
22
from .conv_transform import wavedec, wavedec2, waverec, waverec2
3+
from .continuous_transform import cwt
34
from .matmul_transform import MatrixWavedec, MatrixWaverec
4-
from .matmul_transform_2d import MatrixWavedec2d, MatrixWaverec2d
5+
from .matmul_transform_2 import MatrixWavedec2, MatrixWaverec2
56
from .packets import WaveletPacket, WaveletPacket2D

src/ptwt/continuous_transform.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
"""PyTorch compatible cwt code.
2+
3+
Based on https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
4+
"""
5+
from typing import Tuple, Union
6+
7+
import numpy as np
8+
import torch
9+
from pywt import ContinuousWavelet, DiscreteContinuousWavelet, Wavelet
10+
from pywt._functions import integrate_wavelet, scale2frequency
11+
from torch.fft import fft, ifft
12+
13+
14+
def _next_fast_len(n: int) -> int:
15+
"""Round up size to the nearest power of two.
16+
17+
Given a number of samples `n`, returns the next power of two
18+
following this number to take advantage of FFT speedup.
19+
This fallback is less efficient than `scipy.fftpack.next_fast_len`
20+
Taken from:
21+
https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
22+
"""
23+
return int(2 ** np.ceil(np.log2(n)))
24+
25+
26+
def cwt(
27+
data: torch.Tensor,
28+
scales: Union[np.ndarray, torch.Tensor], # type: ignore
29+
wavelet: Union[ContinuousWavelet, str],
30+
sampling_period: float = 1.0,
31+
) -> Tuple[torch.Tensor, np.ndarray]: # type: ignore
32+
"""Compute the single dimensional continuous wavelet transform.
33+
34+
This function is a PyTorch port of pywt.cwt as found at:
35+
https://github.com/PyWavelets/pywt/blob/master/pywt/_cwt.py
36+
37+
Args:
38+
data (torch.Tensor): The input tensor of shape [batch_size, time].
39+
scales (torch.Tensor or np.array):
40+
The wavelet scales to use. One can use
41+
``f = pywt.scale2frequency(wavelet, scale)/sampling_period`` to determine
42+
what physical frequency, ``f``. Here, ``f`` is in hertz when the
43+
``sampling_period`` is given in seconds.
44+
wavelet (str or Wavelet of ContinuousWavelet): The wavelet to work with.
45+
wavelet (ContinuousWavelet or str): The continuous wavelet to work with.
46+
sampling_period (float): Sampling period for the frequencies output (optional).
47+
The values computed for ``coefs`` are independent of the choice of
48+
``sampling_period`` (i.e. ``scales`` is not scaled by the sampling
49+
period).
50+
51+
Raises:
52+
ValueError: If a scale is too small for the input signal.
53+
54+
Returns:
55+
Tuple[torch.Tensor, np.ndarray]: A tuple with the transformation matrix
56+
and frequencies in this order.
57+
"""
58+
# accept array_like input; make a copy to ensure a contiguous array
59+
if not isinstance(wavelet, (ContinuousWavelet, Wavelet)):
60+
wavelet = DiscreteContinuousWavelet(wavelet)
61+
if type(scales) is torch.Tensor:
62+
scales = scales.numpy()
63+
elif np.isscalar(scales):
64+
scales = np.array([scales])
65+
# if not np.isscalar(axis):
66+
# raise np.AxisError("axis must be a scalar.")
67+
68+
precision = 10
69+
int_psi, x = integrate_wavelet(wavelet, precision=precision)
70+
if type(wavelet) is ContinuousWavelet:
71+
int_psi = np.conj(int_psi) if wavelet.complex_cwt else int_psi
72+
int_psi = torch.tensor(int_psi, device=data.device)
73+
74+
# convert int_psi, x to the same precision as the data
75+
x = np.asarray(x, dtype=data.cpu().numpy().real.dtype)
76+
77+
size_scale0 = -1
78+
fft_data = None
79+
80+
out = []
81+
for scale in scales:
82+
step = x[1] - x[0]
83+
j = np.arange(scale * (x[-1] - x[0]) + 1) / (scale * step)
84+
j = j.astype(int) # floor
85+
if j[-1] >= len(int_psi):
86+
j = np.extract(j < len(int_psi), j)
87+
int_psi_scale = int_psi[j].flip(0)
88+
89+
# The padding is selected for:
90+
# - optimal FFT complexity
91+
# - to be larger than the two signals length to avoid circular
92+
# convolution
93+
size_scale = _next_fast_len(data.shape[-1] + len(int_psi_scale) - 1)
94+
if size_scale != size_scale0:
95+
# Must recompute fft_data when the padding size changes.
96+
fft_data = fft(data, size_scale, dim=-1)
97+
size_scale0 = size_scale
98+
fft_wav = fft(int_psi_scale, size_scale, dim=-1)
99+
conv = ifft(fft_wav * fft_data, dim=-1)
100+
conv = conv[..., : data.shape[-1] + len(int_psi_scale) - 1]
101+
102+
coef = -np.sqrt(scale) * torch.diff(conv, dim=-1)
103+
104+
# transform axis is always -1
105+
d = (coef.shape[-1] - data.shape[-1]) / 2.0
106+
if d > 0:
107+
coef = coef[..., int(np.floor(d)) : -int(np.ceil(d))]
108+
elif d < 0:
109+
raise ValueError("Selected scale of {} too small.".format(scale))
110+
111+
out.append(coef)
112+
out_tensor = torch.stack(out)
113+
if type(wavelet) is Wavelet:
114+
out_tensor = out_tensor.real
115+
else:
116+
out_tensor = out_tensor if wavelet.complex_cwt else out_tensor.real
117+
118+
frequencies = scale2frequency(wavelet, scales, precision)
119+
if np.isscalar(frequencies):
120+
frequencies = np.array([frequencies])
121+
frequencies /= sampling_period
122+
return out_tensor, frequencies

src/ptwt/conv_transform.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _create_tensor(filter: Sequence[float]) -> torch.Tensor:
5353
return dec_lo_tensor, dec_hi_tensor, rec_lo_tensor, rec_hi_tensor
5454

5555

56-
def get_pad(data_len: int, filt_len: int) -> Tuple[int, int]:
56+
def _get_pad(data_len: int, filt_len: int) -> Tuple[int, int]:
5757
"""Compute the required padding.
5858
5959
Args:
@@ -107,12 +107,12 @@ def fwt_pad(
107107
# convert pywt to pytorch convention.
108108
mode = "constant"
109109

110-
padr, padl = get_pad(data.shape[-1], len(wavelet.dec_lo))
110+
padr, padl = _get_pad(data.shape[-1], len(wavelet.dec_lo))
111111
data_pad = torch.nn.functional.pad(data, [padl, padr], mode=mode)
112112
return data_pad
113113

114114

115-
def fwt_pad2d(
115+
def fwt_pad2(
116116
data: torch.Tensor, wavelet: Union[Wavelet, str], level: int, mode: str = "reflect"
117117
) -> torch.Tensor:
118118
"""Pad data for the 2d FWT.
@@ -129,8 +129,8 @@ def fwt_pad2d(
129129
130130
"""
131131
wavelet = _as_wavelet(wavelet)
132-
padb, padt = get_pad(data.shape[-2], len(wavelet.dec_lo))
133-
padr, padl = get_pad(data.shape[-1], len(wavelet.dec_lo))
132+
padb, padt = _get_pad(data.shape[-2], len(wavelet.dec_lo))
133+
padr, padl = _get_pad(data.shape[-1], len(wavelet.dec_lo))
134134
data_pad = torch.nn.functional.pad(data, [padl, padr, padt, padb], mode=mode)
135135
return data_pad
136136

@@ -254,7 +254,7 @@ def wavedec2(
254254
] = []
255255
res_ll = data
256256
for s in range(level):
257-
res_ll = fwt_pad2d(res_ll, wavelet, level=s, mode=mode)
257+
res_ll = fwt_pad2(res_ll, wavelet, level=s, mode=mode)
258258
res = torch.nn.functional.conv2d(res_ll, dec_filt, stride=2)
259259
res_ll, res_lh, res_hl, res_hh = torch.split(res, 1, 1)
260260
result_lst.append((res_lh, res_hl, res_hh))

0 commit comments

Comments
 (0)