"""Compare two multilevel 2D wavelet decompositions of the same image.

Both panels show a 4-level 'db2' transform (mode 'periodization') of the
``pywt.data.camera()`` test image:

* left:  the standard (Mallat) decomposition from ``pywt.wavedec2``, packed
  into a single 2D array with ``pywt.coeffs_to_array``;
* right: the fully separable decomposition from ``pywt.fswavedecn``.
"""
import numpy as np
from matplotlib import pyplot as plt

import pywt


def _show_coeffs(ax, coeffs, title, gamma=0.25):
    """Draw ``abs(coeffs) ** gamma`` as a grayscale image on *ax*.

    The power law with ``gamma < 1`` compresses the dynamic range of the
    coefficient magnitudes so that small coefficients stay visible next to
    large ones. Axis ticks/frame are hidden and *title* is set on *ax*.
    """
    ax.imshow(np.abs(coeffs) ** gamma,
              cmap=plt.cm.gray,
              interpolation='nearest')
    ax.set_axis_off()
    ax.set_title(title)


img = pywt.data.camera().astype(float)

# Fully separable transform
fswavedecn_result = pywt.fswavedecn(img, 'db2', 'periodization', levels=4)

# Standard DWT
coefs = pywt.wavedec2(img, 'db2', 'periodization', level=4)
# Convert the DWT coefficient list to a single 2D array for display;
# mallat_slices is returned alongside it but is not needed for plotting.
mallat_array, mallat_slices = pywt.coeffs_to_array(coefs)


fig, (ax1, ax2) = plt.subplots(1, 2)

_show_coeffs(ax1, mallat_array, 'Mallat decomposition\n(wavedec2)')
# Title names the function actually used, matching the left panel's style.
_show_coeffs(ax2, fswavedecn_result.coeffs,
             'Fully separable decomposition\n(fswavedecn)')

plt.show()