6
from contextlib import contextmanager
7
from itertools import product
12
from torch.testing._internal.common_utils import \
13
(TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
14
make_tensor, skipIfTorchDynamo)
15
from torch.testing._internal.common_device_type import \
16
(instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
17
skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
18
from torch.testing._internal.common_methods_invocations import (
19
spectral_funcs, SpectralFuncType)
20
from torch.testing._internal.common_cuda import SM53OrLater
21
from torch._prims_common import corresponding_complex_dtype
23
from typing import Optional, List
24
from packaging import version
38
# NOTE(review): the ``try:`` block this except belongs to (presumably
# ``import scipy.fft`` setting ``has_scipy_fft = True``) and the except body
# (presumably ``has_scipy_fft = False``) are on lines missing from this chunk.
except ModuleNotFoundError:

# Norm modes usable when comparing against the numpy/scipy reference
# implementations: all four modes require numpy >= 1.20 and, when scipy.fft is
# available, scipy >= 1.6.
# NOTE(review): the ``else`` branch of this conditional expression (the
# reduced tuple for older numpy/scipy) is on lines missing from this chunk.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if version.parse(np.__version__) >= version.parse('1.20.0') and (
        not has_scipy_fft or version.parse(scipy.__version__) >= version.parse('1.6.0'))
48
def _complex_stft(x, *args, **kwargs):
    """STFT of a complex signal built from two real-input STFTs.

    Transforms the real and imaginary parts separately, keeping the full
    two-sided spectrum (``onesided=False``), and recombines them by linearity:
    stft(re + 1j*im) == stft(re) + 1j*stft(im).
    """
    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
    return stft_real + 1j * stft_imag
55
def _hermitian_conj(x, dim):
    """Returns the hermitian conjugate along a single dimension

    Mirrors the positive-frequency half of ``x`` onto the negative-frequency
    half (and vice versa) along ``dim``.
    """
    # NOTE(review): several original lines are missing from this chunk — the
    # statements defining ``idx_center``'s use, ``idx_neg``/``idx_pos``
    # initialisation, any .conj() calls, the Nyquist-bin handling under the
    # final ``if``, and the return. Only the visible statements are kept.
    out = torch.empty_like(x)
    mid = (x.size(dim) - 1) // 2          # count of strictly-positive frequency bins
    idx = [slice(None)] * out.dim()       # index template selecting everything
    idx_center = list(idx)                # presumably narrowed to the DC bin — TODO confirm
    idx_neg[dim] = slice(-mid, None)      # the `mid` negative-frequency bins
    idx_pos[dim] = slice(1, mid + 1)      # the `mid` positive-frequency bins
    # Swap the two halves, flipped along `dim` (conjugation presumably happens
    # on lines missing from this chunk — TODO confirm).
    out[idx_pos] = x[idx_neg].flip(dim)
    out[idx_neg] = x[idx_pos].flip(dim)
    # True only when x has an even length along `dim` (an unpaired Nyquist bin).
    if (2 * mid + 1 < x.size(dim)):
80
def _complex_istft(x, *args, **kwargs):
    """Inverse STFT of a complex spectrogram via two real-output istfts.

    Splits ``x`` into hermitian and anti-hermitian parts (whose one-sided
    istfts are purely real) and recombines them into a complex signal.
    """
    # NOTE(review): the lines deriving ``n_fft`` (presumably from x.size(-2))
    # are missing from this chunk — TODO confirm.
    # Keep only the one-sided half of the frequency axis (dim -2).
    slc = (Ellipsis, slice(None, n_fft // 2 + 1), slice(None))

    hconj = _hermitian_conj(x, dim=-2)
    x_hermitian = (x + hconj) / 2          # hermitian part -> real-valued istft
    x_antihermitian = (x - hconj) / 2      # anti-hermitian part
    istft_real = torch.istft(x_hermitian[slc], *args, **kwargs, onesided=True)
    # -1j makes the anti-hermitian part hermitian so its istft is also real.
    istft_imag = torch.istft(-1j * x_antihermitian[slc], *args, **kwargs, onesided=True)
    return torch.complex(istft_real, istft_imag)
93
def _stft_reference(x, hop_length, window):
    r"""Reference stft implementation

    This doesn't implement all of torch.stft, only the STFT definition:

    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}
    """
    n_fft = window.numel()
    # One column per frame that fits the signal at this hop length.
    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
                    device=x.device, dtype=torch.cdouble)
    for m in range(X.size(1)):
        start = m * hop_length
        if start + n_fft > x.numel():
            # Last frame runs off the end of the signal: pad the tail.
            slc = torch.empty(n_fft, device=x.device, dtype=x.dtype)
            # NOTE(review): the line defining ``tmp`` (presumably
            # ``tmp = x[start:]``) is missing from this chunk — TODO confirm.
            slc[:tmp.numel()] = tmp
            # NOTE(review): the ``else:`` introducing the next statement is on
            # a line missing from this chunk.
            slc = x[start: start + n_fft]
        X[:, m] = torch.fft.fft(slc * window)
    # NOTE(review): the final ``return X`` is on a line missing from this chunk.
116
def skip_helper_for_fft(device, dtype):
    """Skip the calling test when (device, dtype) cannot run FFTs.

    half/complex32 are rejected on CPU and (per the message below) only
    supported on CUDA devices newer than SM53.
    """
    device_type = torch.device(device).type
    if dtype not in (torch.half, torch.complex32):
        # NOTE(review): the early ``return`` for supported dtypes is on a line
        # missing from this chunk — TODO confirm.
    if device_type == 'cpu':
        raise unittest.SkipTest("half and complex32 are not supported on CPU")
    # NOTE(review): a guard such as ``if not SM53OrLater:`` is on lines missing
    # from this chunk before this raise.
    raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
128
class TestFFT(TestCase):
131
@onlyNativeDeviceTypes
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
     allowed_dtypes=(torch.float, torch.cfloat))
def test_reference_1d(self, device, dtype, op):
    """Compare each 1-D spectral op against its numpy reference (op.ref)."""
    # NOTE(review): the guard around this skip (presumably ``if op.ref is
    # None:``) is on a line missing from this chunk.
    raise unittest.SkipTest("No reference implementation")
    norm_modes = REFERENCE_NORM_MODES
    # Inputs: odd/even 1-D lengths plus batched 2-D/3-D/4-D shapes.
    # NOTE(review): the ``test_args = list(product(`` opener and the other
    # argument factors are on lines missing from this chunk.
    (torch.randn(67, device=device, dtype=dtype),
     torch.randn(80, device=device, dtype=dtype),
     torch.randn(12, 14, device=device, dtype=dtype),
     torch.randn(9, 6, 3, device=device, dtype=dtype)),
    (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
    for iargs in test_args:
        # NOTE(review): unpacking of ``iargs`` into ``input``/``args`` is on
        # lines missing from this chunk.
        expected = op.ref(input.cpu().numpy(), *args)
        # Only double/complex128 can match numpy's float64 output exactly.
        exact_dtype = dtype in (torch.double, torch.complex128)
        actual = op(input, *args)
        self.assertEqual(actual, expected, exact_dtype=exact_dtype)
173
@onlyNativeDeviceTypes
# NOTE(review): the ``@toleranceOverride({`` opener and closing ``})`` around
# these two entries are on lines missing from this chunk.
    torch.half : tol(1e-2, 1e-2),
    torch.chalf : tol(1e-2, 1e-2),
@dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
def test_fft_round_trip(self, device, dtype):
    """Check that backward(forward(x)) round-trips for fft/rfft/ihfft pairs."""
    skip_helper_for_fft(device, dtype)

    if dtype not in (torch.half, torch.complex32):
        test_args = list(product(
            # input
            (torch.randn(67, device=device, dtype=dtype),
             torch.randn(80, device=device, dtype=dtype),
             torch.randn(12, 14, device=device, dtype=dtype),
             torch.randn(9, 6, 3, device=device, dtype=dtype)),
            # norm modes
            (None, "forward", "backward", "ortho")
    # NOTE(review): a ``dim`` factor, the closing of ``product(...)`` and the
    # ``else:`` header are on lines missing from this chunk; the shapes below
    # are powers of two per the half/complex32 message in skip_helper_for_fft.
        test_args = list(product(
            (torch.randn(64, device=device, dtype=dtype),
             torch.randn(128, device=device, dtype=dtype),
             torch.randn(4, 16, device=device, dtype=dtype),
             torch.randn(8, 6, 2, device=device, dtype=dtype)),
            (None, "forward", "backward", "ortho")

    fft_functions = [(torch.fft.fft, torch.fft.ifft)]
    # Real-to-complex pairs (and their inverses) only apply to real input.
    if not dtype.is_complex:
        fft_functions += [(torch.fft.rfft, torch.fft.irfft),
                          (torch.fft.ihfft, torch.fft.hfft)]

    for forward, backward in fft_functions:
        for x, dim, norm in test_args:
            # NOTE(review): the construction of ``kwargs`` (presumably from
            # dim/norm) is on lines missing from this chunk — TODO confirm.
            y = backward(forward(x, **kwargs), **kwargs)
            if x.dtype is torch.half and y.dtype is torch.complex32:
                # c2c promotes half -> complex32; compare in the promoted dtype.
                x = x.to(torch.complex32)
            self.assertEqual(x, y, exact_dtype=(
                forward != torch.fft.fft or x.is_complex()))
234
@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
def test_empty_fft(self, device, dtype, op):
    """Transforming a size-0 signal dimension must raise RuntimeError."""
    t = torch.empty(1, 0, device=device, dtype=dtype)
    match = r"Invalid number of data points \([-\d]*\) specified"

    with self.assertRaisesRegex(RuntimeError, match):
        # NOTE(review): the op invocation inside this context manager is on a
        # line missing from this chunk.

@onlyNativeDeviceTypes
def test_empty_ifft(self, device):
    """Inverse/hermitian transforms of an empty spectrum must raise."""
    t = torch.empty(2, 1, device=device, dtype=torch.complex64)
    match = r"Invalid number of data points \([-\d]*\) specified"

    for f in [torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
              torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn]:
        with self.assertRaisesRegex(RuntimeError, match):
            # NOTE(review): the call to ``f`` inside this context manager is on
            # a line missing from this chunk.

@onlyNativeDeviceTypes
def test_fft_invalid_dtypes(self, device):
    """Real-input transforms must reject complex tensors with clear errors."""
    t = torch.randn(64, device=device, dtype=torch.complex128)

    # NOTE(review): the offending calls under each context manager below are
    # on lines missing from this chunk.
    with self.assertRaisesRegex(RuntimeError, "rfft expects a real input tensor"):

    with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input tensor"):

    with self.assertRaisesRegex(RuntimeError, "ihfft expects a real input tensor"):
267
@onlyNativeDeviceTypes
@dtypes(torch.int8, torch.half, torch.float, torch.double,
        torch.complex32, torch.complex64, torch.complex128)
def test_fft_type_promotion(self, device, dtype):
    """fft/hfft/rfft must promote input dtypes according to the maps below."""
    skip_helper_for_fft(device, dtype)

    if dtype.is_complex or dtype.is_floating_point:
        t = torch.randn(64, device=device, dtype=dtype)
    # NOTE(review): the ``else:`` header is on a line missing from this chunk.
        t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)

    # Expected complex output dtype of the c2c transform.
    # NOTE(review): the ``PROMOTION_MAP = {`` opener, the closing ``}`` and the
    # fft call assigning ``T`` are on lines missing from this chunk.
        torch.int8: torch.complex64,
        torch.half: torch.complex32,
        torch.float: torch.complex64,
        torch.double: torch.complex128,
        torch.complex32: torch.complex32,
        torch.complex64: torch.complex64,
        torch.complex128: torch.complex128,
    self.assertEqual(T.dtype, PROMOTION_MAP[dtype])

    # Expected real output dtype of the complex-to-real transform (hfft).
    # NOTE(review): the closing ``}`` of this dict is on a missing line.
    PROMOTION_MAP_C2R = {
        torch.int8: torch.float,
        torch.half: torch.half,
        torch.float: torch.float,
        torch.double: torch.double,
        torch.complex32: torch.half,
        torch.complex64: torch.float,
        torch.complex128: torch.double,
    if dtype in (torch.half, torch.complex32):
        # length 65 -> hfft output length 128, a power of two as required for
        # half precision — presumably why a separate input is used here; TODO confirm.
        x = torch.randn(65, device=device, dtype=dtype)
        R = torch.fft.hfft(x)
    # NOTE(review): the ``else:`` header is on a line missing from this chunk.
        R = torch.fft.hfft(t)
    self.assertEqual(R.dtype, PROMOTION_MAP_C2R[dtype])

    if not dtype.is_complex:
        # Expected complex output dtype of the real-to-complex transform (rfft).
        # NOTE(review): the closing ``}`` of this dict is on a missing line.
        PROMOTION_MAP_R2C = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
        C = torch.fft.rfft(t)
        self.assertEqual(C.dtype, PROMOTION_MAP_R2C[dtype])
319
@onlyNativeDeviceTypes
@ops(spectral_funcs, dtypes=OpDTypes.unsupported,
     allowed_dtypes=[torch.half, torch.bfloat16])
def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
    """Unsupported half/bfloat16 configurations must raise with the expected
    message (generic, or the SM_53 capability message on old CUDA GPUs)."""
    sample = first_sample(self, op.sample_inputs(device, dtype))
    device_type = torch.device(device).type
    default_msg = "Unsupported dtype"
    if dtype is torch.half and device_type == 'cuda' and TEST_WITH_ROCM:
        err_msg = default_msg
    elif dtype is torch.half and device_type == 'cuda' and not SM53OrLater:
        err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
    # NOTE(review): the ``else:`` header is on a line missing from this chunk.
        err_msg = default_msg
    with self.assertRaisesRegex(RuntimeError, err_msg):
        op(sample.input, *sample.args, **sample.kwargs)

@onlyNativeDeviceTypes
@ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
    """half/chalf transforms must reject non-power-of-two sizes (13 here)."""
    t = make_tensor(13, 13, device=device, dtype=dtype)
    err_msg = "cuFFT only supports dimensions whose sizes are powers of two"
    with self.assertRaisesRegex(RuntimeError, err_msg):
        # NOTE(review): the bare op invocation under this context manager is on
        # a line missing from this chunk.

    if op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD):
        kwargs = {'s': (12, 12)}
    # NOTE(review): the ``else:`` branch building 1-D kwargs is on lines
    # missing from this chunk — TODO confirm.
    with self.assertRaisesRegex(RuntimeError, err_msg):
        # NOTE(review): the op invocation with ``**kwargs`` is on a line
        # missing from this chunk.
353
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
     allowed_dtypes=(torch.cfloat, torch.cdouble))
def test_reference_nd(self, device, dtype, op):
    """Compare each N-D spectral op against its numpy reference (op.ref)."""
    # NOTE(review): the guard around this skip (presumably ``if op.ref is
    # None:``) is on a line missing from this chunk.
    raise unittest.SkipTest("No reference implementation")
    norm_modes = REFERENCE_NORM_MODES

    # (input_ndim, s, dim) combinations to exercise.
    # NOTE(review): the ``transform_desc = [`` opener, some entries, and the
    # closing ``]`` are on lines missing from this chunk.
    *product(range(2, 5), (None,), (None, (0,), (0, -1))),
    *product(range(2, 5), (None, (4, 10)), (None,)),
    (5, None, (1, 3, 4)),
    (4, (10, 10), (0, 1))

    for input_ndim, s, dim in transform_desc:
        # Cycle sizes 4..8 to build an input of the requested rank.
        shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        input = torch.randn(*shape, device=device, dtype=dtype)

        for norm in norm_modes:
            expected = op.ref(input.cpu().numpy(), s, dim, norm)
            # Only double/complex128 can match numpy's precision exactly.
            exact_dtype = dtype in (torch.double, torch.complex128)
            actual = op(input, s, dim, norm)
            self.assertEqual(actual, expected, exact_dtype=exact_dtype)
386
@onlyNativeDeviceTypes
# NOTE(review): the ``@toleranceOverride({`` opener and closing ``})`` around
# these two entries are on lines missing from this chunk.
    torch.half : tol(1e-2, 1e-2),
    torch.chalf : tol(1e-2, 1e-2),
@dtypes(torch.half, torch.float, torch.double,
        torch.complex32, torch.complex64, torch.complex128)
def test_fftn_round_trip(self, device, dtype):
    """Check backward(forward(x)) round-trips for fftn/rfftn/ihfftn pairs."""
    skip_helper_for_fft(device, dtype)
    norm_modes = (None, "forward", "backward", "ortho")

    # (input_ndim, dim) combinations.
    # NOTE(review): the ``transform_desc = [`` opener and further entries are
    # on lines missing from this chunk.
    *product(range(2, 5), (None, (0,), (0, -1))),

    fft_functions = [(torch.fft.fftn, torch.fft.ifftn)]
    # Real-to-complex pairs only apply to real input dtypes.
    if not dtype.is_complex:
        fft_functions += [(torch.fft.rfftn, torch.fft.irfftn),
                          (torch.fft.ihfftn, torch.fft.hfftn)]

    for input_ndim, dim in transform_desc:
        if dtype in (torch.half, torch.complex32):
            # Power-of-two sizes for half precision.
            shape = itertools.islice(itertools.cycle((2, 4, 8)), input_ndim)
        # NOTE(review): the ``else:`` header is on a line missing from this chunk.
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        x = torch.randn(*shape, device=device, dtype=dtype)

        for (forward, backward), norm in product(fft_functions, norm_modes):
            if isinstance(dim, tuple):
                s = [x.size(d) for d in dim]
            # NOTE(review): the ``else:`` header is on a line missing from this chunk.
                s = x.size() if dim is None else x.size(dim)
            kwargs = {'s': s, 'dim': dim, 'norm': norm}
            y = backward(forward(x, **kwargs), **kwargs)
            # half inputs come back as chalf from c2c transforms.
            if x.dtype is torch.half and y.dtype is torch.chalf:
                self.assertEqual(x.to(torch.chalf), y)
            # NOTE(review): the ``else:`` header is on a line missing from this chunk.
                self.assertEqual(x, y, exact_dtype=(
                    forward != torch.fft.fftn or x.is_complex()))
441
@onlyNativeDeviceTypes
@ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
     allowed_dtypes=[torch.float, torch.cfloat])
def test_fftn_invalid(self, device, dtype, op):
    """N-D transforms must reject duplicate dims, mismatched s/dim lengths,
    out-of-range dims, and over-long shape arguments."""
    a = torch.rand(10, 10, 10, device=device, dtype=dtype)
    errMsg = "dims must be unique"
    # NOTE(review): the offending op calls under the first, second and fourth
    # context managers are on lines missing from this chunk.
    with self.assertRaisesRegex(RuntimeError, errMsg):

    with self.assertRaisesRegex(RuntimeError, errMsg):

    with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):
        op(a, s=(1,), dim=(0, 1))

    with self.assertRaisesRegex(IndexError, "Dimension out of range"):

    with self.assertRaisesRegex(RuntimeError, "tensor only has 3 dimensions"):
        op(a, s=(10, 10, 10, 10))

@onlyNativeDeviceTypes
@dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
def test_fftn_noop_transform(self, device, dtype):
    """fftn with dim=[] transforms nothing but still promotes to complex."""
    skip_helper_for_fft(device, dtype)
    # Maps each real dtype to its complex counterpart; complex dtypes fall
    # through the .get() below unchanged.
    # NOTE(review): the ``RESULT_TYPE = {`` opener, the closing ``}`` and any
    # surrounding loop are on lines missing from this chunk.
        torch.half: torch.chalf,
        torch.float: torch.cfloat,
        torch.double: torch.cdouble,

    inp = make_tensor((10, 10), device=device, dtype=dtype)
    out = torch.fft.fftn(inp, dim=[])

    expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
    expect = inp.to(expect_dtype)
    self.assertEqual(expect, out)
489
@onlyNativeDeviceTypes
# NOTE(review): the ``@toleranceOverride({`` opener and closing ``})`` around
# this entry are on lines missing from this chunk.
    torch.half : tol(1e-2, 1e-2),
@dtypes(torch.half, torch.float, torch.double)
def test_hfftn(self, device, dtype):
    """hfftn of the one-sided ifftn of a real signal must reproduce it."""
    skip_helper_for_fft(device, dtype)

    # (input_ndim, dim) combinations.
    # NOTE(review): the ``transform_desc = [`` opener and further entries are
    # on lines missing from this chunk.
    *product(range(2, 5), (None, (0,), (0, -1))),

    for input_ndim, dim in transform_desc:
        actual_dims = list(range(input_ndim)) if dim is None else dim
        if dtype is torch.half:
            # Power-of-two sizes for half precision.
            shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
        # NOTE(review): the ``else:`` header is on a line missing from this chunk.
            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
        expect = torch.randn(*shape, device=device, dtype=dtype)
        input = torch.fft.ifftn(expect, dim=dim, norm="ortho")

        # Drop the redundant negative frequencies along the last transformed
        # dim to form a one-sided (hermitian) spectrum.
        lastdim = actual_dims[-1]
        lastdim_size = input.size(lastdim) // 2 + 1
        idx = [slice(None)] * input_ndim
        idx[lastdim] = slice(0, lastdim_size)
        # NOTE(review): the line applying ``idx`` to ``input`` is on a line
        # missing from this chunk.

        s = [shape[dim] for dim in actual_dims]
        actual = torch.fft.hfftn(input, s=s, dim=dim, norm="ortho")
        self.assertEqual(expect, actual)

@onlyNativeDeviceTypes
# NOTE(review): the ``@toleranceOverride({`` opener and closing ``})`` around
# this entry are on lines missing from this chunk.
    torch.half : tol(1e-2, 1e-2),
@dtypes(torch.half, torch.float, torch.double)
def test_ihfftn(self, device, dtype):
    """ihfftn must equal ifftn restricted to the one-sided half spectrum."""
    skip_helper_for_fft(device, dtype)

    # (input_ndim, dim) combinations.
    # NOTE(review): the ``transform_desc = [`` opener and further entries are
    # on lines missing from this chunk.
    *product(range(2, 5), (None, (0,), (0, -1))),

    for input_ndim, dim in transform_desc:
        if dtype is torch.half:
            # Power-of-two sizes for half precision.
            shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
        # NOTE(review): the ``else:`` header is on a line missing from this chunk.
            shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
        input = torch.randn(*shape, device=device, dtype=dtype)
        # Full ifftn, then keep only the one-sided half along the last dim.
        expect = torch.fft.ifftn(input, dim=dim, norm="ortho")
        lastdim = -1 if dim is None else dim[-1]
        lastdim_size = expect.size(lastdim) // 2 + 1
        idx = [slice(None)] * input_ndim
        idx[lastdim] = slice(0, lastdim_size)
        # NOTE(review): the line slicing ``expect`` with ``idx`` is on a line
        # missing from this chunk.
        actual = torch.fft.ihfftn(input, dim=dim, norm="ortho")
        self.assertEqual(expect, actual)
573
@onlyNativeDeviceTypes
@dtypes(torch.double, torch.complex128)
def test_fft2_numpy(self, device, dtype):
    """Compare 2-D transforms against numpy/scipy references, both eagerly
    and through a TorchScript-compiled wrapper."""
    norm_modes = REFERENCE_NORM_MODES

    # (input_ndim, s) combinations.
    # NOTE(review): the ``transform_desc = [`` opener and the closing ``]``
    # are on lines missing from this chunk.
    *product(range(2, 5), (None, (4, 10))),

    fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
    if dtype.is_floating_point:
        fft_functions += ['rfft2', 'ihfft2']

    for input_ndim, s in transform_desc:
        shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        input = torch.randn(*shape, device=device, dtype=dtype)
        for fname, norm in product(fft_functions, norm_modes):
            torch_fn = getattr(torch.fft, fname)

            # NOTE(review): the branching that selects scipy.fft vs np.fft as
            # the reference (presumably hfft2/ihfft2 exist only in scipy.fft)
            # is partially on lines missing from this chunk — TODO confirm.
            if not has_scipy_fft:
            numpy_fn = getattr(scipy.fft, fname)
            numpy_fn = getattr(np.fft, fname)

            # TorchScript-compatible wrapper with the default dim=(-2, -1).
            def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
                return torch_fn(t, s, dim, norm)
            torch_fns = (torch_fn, torch.jit.script(fn))

            # Once with default dim.
            input_np = input.cpu().numpy()
            expected = numpy_fn(input_np, s, norm=norm)
            # NOTE(review): the loop over ``torch_fns`` binding ``fn`` is on
            # lines missing from this chunk — TODO confirm.
            actual = fn(input, s, norm=norm)
            self.assertEqual(actual, expected)

            # Once with explicit dims.
            expected = numpy_fn(input_np, s, dim, norm)
            actual = fn(input, s, dim, norm)
            self.assertEqual(actual, expected)
619
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.complex64)
def test_fft2_fftn_equivalence(self, device, dtype):
    """Each 2-D transform must match its N-D counterpart on dim=(-2, -1)."""
    norm_modes = (None, "forward", "backward", "ortho")

    # (input_ndim, s, dim) combinations.
    # NOTE(review): the ``transform_desc = [`` opener and closing ``]`` are on
    # lines missing from this chunk.
    *product(range(2, 5), (None, (4, 10)), (None, (1, 0))),

    fft_functions = ['fft', 'ifft', 'irfft', 'hfft']
    # Real-input-only transforms.
    if dtype.is_floating_point:
        fft_functions += ['rfft', 'ihfft']

    for input_ndim, s, dim in transform_desc:
        shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
        x = torch.randn(*shape, device=device, dtype=dtype)

        for func, norm in product(fft_functions, norm_modes):
            f2d = getattr(torch.fft, func + '2')
            fnd = getattr(torch.fft, func + 'n')

            kwargs = {'s': s, 'norm': norm}
            # NOTE(review): the ``if dim is not None:`` / ``else:`` branching
            # around these two assignments is on lines missing from this chunk.
            expect = fnd(x, **kwargs)
            # 2-D functions default to the last two dims.
            expect = fnd(x, dim=(-2, -1), **kwargs)

            actual = f2d(x, **kwargs)
            self.assertEqual(actual, expect)

@onlyNativeDeviceTypes
def test_fft2_invalid(self, device):
    """2-D transforms must reject duplicate dims and mismatched s/dim, and
    real-input variants must reject complex tensors."""
    a = torch.rand(10, 10, 10, device=device)
    fft_funcs = (torch.fft.fft2, torch.fft.ifft2,
                 torch.fft.rfft2, torch.fft.irfft2)

    for func in fft_funcs:
        # NOTE(review): the offending calls under the four context managers
        # below are on lines missing from this chunk.
        with self.assertRaisesRegex(RuntimeError, "dims must be unique"):

        with self.assertRaisesRegex(RuntimeError, "dims must be unique"):

        with self.assertRaisesRegex(RuntimeError, "dim and shape .* same length"):

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):

    c = torch.complex(a, a)
    with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
        # NOTE(review): the rfft2 call on ``c`` is on a line missing from this
        # chunk.
682
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
def test_fftfreq_numpy(self, device, dtype):
    """torch.fft.fftfreq/rfftfreq must agree with their numpy counterparts."""
    # NOTE(review): the construction of ``test_args`` (n, d pairs, with d
    # possibly None) is on lines missing from this chunk.
    functions = ['fftfreq', 'rfftfreq']

    for fname in functions:
        torch_fn = getattr(torch.fft, fname)
        numpy_fn = getattr(np.fft, fname)

        for n, d in test_args:
            args = (n,) if d is None else (n, d)
            expected = numpy_fn(*args)
            actual = torch_fn(*args, device=device, dtype=dtype)
            # numpy always returns float64; compare values only, not dtype.
            self.assertEqual(actual, expected, exact_dtype=False)
708
@onlyNativeDeviceTypes
@dtypes(torch.float, torch.double)
def test_fftfreq_out(self, device, dtype):
    """Passing a wrongly-sized ``out=`` tensor must warn and resize it to
    hold the same values as the out-of-place call."""
    for fn in (torch.fft.fftfreq, torch.fft.rfftfreq):
        reference = fn(n=100, d=.5, device=device, dtype=dtype)
        # 0-dim tensor: deliberately the wrong size for n=100 frequencies.
        result = torch.empty((), device=device, dtype=dtype)
        with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
            fn(n=100, d=.5, out=result)
        self.assertEqual(result, reference)
720
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
def test_fftshift_numpy(self, device, dtype):
    """fftshift/ifftshift must match np.fft over shapes and dim choices."""
    # (shape, dim) combinations, including dim=None (all dims) and tuples.
    # NOTE(review): the ``test_args = [`` opener and closing ``]`` are on
    # lines missing from this chunk.
    *product(((11,), (12,)), (None, 0, -1)),
    *product(((4, 5), (6, 6)), (None, 0, (-1,))),
    *product(((1, 1, 4, 6, 7, 2),), (None, (3, 4))),

    functions = ['fftshift', 'ifftshift']

    for shape, dim in test_args:
        input = torch.rand(*shape, device=device, dtype=dtype)
        input_np = input.cpu().numpy()

        for fname in functions:
            torch_fn = getattr(torch.fft, fname)
            numpy_fn = getattr(np.fft, fname)

            expected = numpy_fn(input_np, axes=dim)
            actual = torch_fn(input, dim=dim)
            self.assertEqual(actual, expected)
746
@onlyNativeDeviceTypes
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
@dtypes(torch.float, torch.double)
def test_fftshift_frequencies(self, device, dtype):
    """fftshift must reorder fftfreq output into ascending frequencies, and
    ifftshift must invert it exactly."""
    for num_samples in range(10, 15):
        expected_sorted = torch.arange(-(num_samples // 2),
                                       num_samples - (num_samples // 2),
                                       device=device, dtype=dtype)
        freqs = torch.fft.fftfreq(num_samples, d=1 / num_samples,
                                  device=device, dtype=dtype)

        # After shifting, the frequencies are exactly the sorted integer bins.
        centred = torch.fft.fftshift(freqs)
        self.assertEqual(centred, centred.sort().values)
        self.assertEqual(expected_sorted, centred)

        # ifftshift undoes fftshift exactly.
        self.assertEqual(freqs, torch.fft.ifftshift(centred))
764
def _test_fft_ifft_rfft_irfft(self, device, dtype):
    """Round-trip fftn/ifftn and rfftn/irfftn over many shapes, including
    non-contiguous inputs produced by the ``prepro_fn`` hooks."""
    complex_dtype = corresponding_complex_dtype(dtype)

    def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
        # c2c round trips in both orders: ifftn∘fftn and fftn∘ifftn.
        x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
        dim = tuple(range(-signal_ndim, 0))
        for norm in ('ortho', None):
            res = torch.fft.fftn(x, dim=dim, norm=norm)
            rec = torch.fft.ifftn(res, dim=dim, norm=norm)
            self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft')
            res = torch.fft.ifftn(x, dim=dim, norm=norm)
            rec = torch.fft.fftn(res, dim=dim, norm=norm)
            self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft')

    def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
        # r2c round trip; irfftn needs explicit signal sizes to recover the
        # original (possibly odd) lengths unambiguously.
        x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
        signal_sizes = x.size()[-signal_ndim:]
        dim = tuple(range(-signal_ndim, 0))
        for norm in (None, 'ortho'):
            res = torch.fft.rfftn(x, dim=dim, norm=norm)
            rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
            self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')
            # c2c round trip of real input compares against its complex view.
            res = torch.fft.fftn(x, dim=dim, norm=norm)
            rec = torch.fft.ifftn(res, dim=dim, norm=norm)
            x_complex = torch.complex(x, torch.zeros_like(x))
            self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')

    # Contiguous inputs of various ranks and signal dimensions.
    _test_real((100,), 1)
    _test_real((10, 1, 10, 100), 1)
    _test_real((100, 100), 2)
    _test_real((2, 2, 5, 80, 60), 2)
    _test_real((50, 40, 70), 3)
    _test_real((30, 1, 50, 25, 20), 3)

    _test_complex((100,), 1)
    _test_complex((100, 100), 1)
    _test_complex((100, 100), 2)
    _test_complex((1, 20, 80, 60), 2)
    _test_complex((50, 40, 70), 3)
    _test_complex((6, 5, 50, 25, 20), 3)

    # Non-contiguous inputs via narrow/select/transpose/view/expand.
    _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))
    _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
    _test_real((100, 100), 2, lambda x: x.t())
    _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
    _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
    _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))

    _test_complex((100,), 1, lambda x: x.expand(100, 100))
    _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
    _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
    _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])

@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_fft_ifft_rfft_irfft(self, device, dtype):
    """Native-device entry point for the shared round-trip helper."""
    self._test_fft_ifft_rfft_irfft(device, dtype)
826
@deviceCountAtLeast(1)
@dtypes(torch.double)
def test_cufft_plan_cache(self, devices, dtype):
    """Exercise the cuFFT plan cache: size limits, clearing, per-device
    caches, and error handling for invalid configuration."""
    # NOTE(review): the ``@contextmanager`` decorator and the if/else and
    # try/finally scaffolding of this helper are on lines missing from this
    # chunk — TODO confirm; presumably device=None targets the current device.
    def plan_cache_max_size(device, n):
        plan_cache = torch.backends.cuda.cufft_plan_cache
        plan_cache = torch.backends.cuda.cufft_plan_cache[device]
        original = plan_cache.max_size
        plan_cache.max_size = n
        # Restores the original limit on exit.
        plan_cache.max_size = original

    # Shrink the cache below its current size, then disable it entirely.
    with plan_cache_max_size(devices[0], max(1, torch.backends.cuda.cufft_plan_cache.size - 10)):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    with plan_cache_max_size(devices[0], 0):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    torch.backends.cuda.cufft_plan_cache.clear()

    # FFTs must still work after clearing the cache.
    with plan_cache_max_size(devices[0], 10):
        self._test_fft_ifft_rfft_irfft(devices[0], dtype)

    with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
        torch.backends.cuda.cufft_plan_cache.max_size = -1

    with self.assertRaisesRegex(RuntimeError, r"read-only property"):
        torch.backends.cuda.cufft_plan_cache.size = -1

    with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
        torch.backends.cuda.cufft_plan_cache[torch.cuda.device_count() + 10]

    # NOTE(review): a multi-device guard (presumably ``if len(devices) > 1:``)
    # is on lines missing from this chunk; the remainder uses two devices.
    x0 = torch.randn(2, 3, 3, device=devices[0])
    x1 = x0.to(devices[1])
    self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))

    with plan_cache_max_size(devices[0], 10):
        with plan_cache_max_size(devices[1], 11):
            self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
            self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

            # Un-indexed access follows the current CUDA device.
            self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)
            with torch.cuda.device(devices[1]):
                self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)
                with torch.cuda.device(devices[0]):
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)

        self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
        with torch.cuda.device(devices[1]):
            # plan_cache_max_size(None, ...) targets the current device (1).
            with plan_cache_max_size(None, 11):
                self.assertEqual(torch.backends.cuda.cufft_plan_cache[0].max_size, 10)
                self.assertEqual(torch.backends.cuda.cufft_plan_cache[1].max_size, 11)

                self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)
                with torch.cuda.device(devices[0]):
                    self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 10)
                self.assertEqual(torch.backends.cuda.cufft_plan_cache.max_size, 11)
901
@dtypes(torch.cfloat, torch.cdouble)
def test_cufft_context(self, device, dtype):
    """Backprop through ifft(fft(x)) must yield exactly fft(ifft(grad))."""
    signal = torch.randn(32, dtype=dtype, device=device, requires_grad=True)
    cotangent = torch.zeros(32, dtype=dtype, device=device)

    # Forward round trip, then backprop a zero cotangent through it.
    round_trip = torch.fft.ifft(torch.fft.fft(signal))
    round_trip.backward(cotangent, retain_graph=True)

    # The gradient the chain fft -> ifft should produce for this cotangent.
    expected_grad = torch.fft.fft(torch.fft.ifft(cotangent))

    self.assertTrue((signal.grad - expected_grad).abs().max() == 0)
    self.assertFalse((signal.grad - signal).abs().max() == 0)
917
@skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array")
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_stft(self, device, dtype):
    """Compare torch.stft against librosa.stft across window/hop/center
    configurations, plus invalid-argument error cases."""
    # NOTE(review): the TEST_LIBROSA guard around this skip is on a line
    # missing from this chunk.
    raise unittest.SkipTest('librosa not found')

    def librosa_stft(x, n_fft, hop_length, win_length, window, center):
        # librosa needs a numpy window; default is rectangular (all ones).
        # NOTE(review): the ``if window is None:`` / ``else:`` headers are on
        # lines missing from this chunk.
        window = np.ones(n_fft if win_length is None else win_length)
        window = window.cpu().numpy()
        input_1d = x.dim() == 1
        # NOTE(review): the batching loop that builds ``result`` (and its
        # return) is mostly on lines missing from this chunk.
        ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                          win_length=win_length, window=window, center=center,
        result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
        result = torch.stack(result, 0)

    def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
              center=True, expected_error=None):
        x = torch.randn(*sizes, dtype=dtype, device=device)
        if win_sizes is not None:
            window = torch.randn(*win_sizes, dtype=dtype, device=device)
        # NOTE(review): the ``else: window = None`` branch is on lines missing
        # from this chunk — TODO confirm.
        if expected_error is None:
            result = x.stft(n_fft, hop_length, win_length, window,
                            center=center, return_complex=False)
            ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
            self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False)
            # return_complex=True must be the complex view of the real result.
            result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True)
            self.assertEqual(result_complex, torch.view_as_complex(result))
        # NOTE(review): the ``else:`` header is on a line missing from this chunk.
            self.assertRaises(expected_error,
                              lambda: x.stft(n_fft, hop_length, win_length, window, center=center))

    for center in [True, False]:
        _test((10,), 7, center=center)
        _test((10, 4000), 1024, center=center)

        _test((10,), 7, 2, center=center)
        _test((10, 4000), 1024, 512, center=center)

        _test((10,), 7, 2, win_sizes=(7,), center=center)
        _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)

        # Explicit win_length shorter than n_fft.
        _test((10,), 7, 2, win_length=5, center=center)
        _test((10, 4000), 1024, 512, win_length=100, center=center)

    # Invalid configurations: 3-D input, n_fft too large/negative, window
    # mismatches — all must raise RuntimeError.
    _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
    _test((10,), 11, 1, center=False, expected_error=RuntimeError)
    _test((10,), -1, 1, expected_error=RuntimeError)
    _test((10,), 3, win_length=5, expected_error=RuntimeError)
    _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
    _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
991
@skipIfTorchDynamo("double")
@onlyNativeDeviceTypes
@dtypes(torch.double)
def test_istft_against_librosa(self, device, dtype):
    """Compare torch.istft against librosa.istft on torch.stft output."""
    # NOTE(review): the TEST_LIBROSA guard around this skip is on a line
    # missing from this chunk.
    raise unittest.SkipTest('librosa not found')

    def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
        # librosa needs a numpy window; default is rectangular (all ones).
        # NOTE(review): the ``if window is None:`` / ``else:`` headers are on
        # lines missing from this chunk.
        window = np.ones(n_fft if win_length is None else win_length)
        window = window.cpu().numpy()

        return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                             win_length=win_length, length=length, window=window, center=center)

    def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
              length=None, center=True):
        x = torch.randn(size, dtype=dtype, device=device)
        if win_sizes is not None:
            window = torch.randn(*win_sizes, dtype=dtype, device=device)
        # NOTE(review): the ``else: window = None`` branch is on lines missing
        # from this chunk — TODO confirm.

        x_stft = x.stft(n_fft, hop_length, win_length, window, center=center,
                        onesided=True, return_complex=True)

        ref_result = librosa_istft(x_stft, n_fft, hop_length, win_length,
                                   window, length, center)
        result = x_stft.istft(n_fft, hop_length, win_length, window,
                              length=length, center=center)
        self.assertEqual(result, ref_result)

    for center in [True, False]:
        _test(10, 7, center=center)
        _test(4000, 1024, center=center)
        _test(4000, 1024, center=center, length=4000)

        _test(10, 7, 2, center=center)
        _test(4000, 1024, 512, center=center)
        _test(4000, 1024, 512, center=center, length=4000)

        _test(10, 7, 2, win_sizes=(7,), center=center)
        _test(4000, 1024, 512, win_sizes=(1024,), center=center)
        _test(4000, 1024, 512, win_sizes=(1024,), center=center, length=4000)
1038
@onlyNativeDeviceTypes
@dtypes(torch.double, torch.cdouble)
def test_complex_stft_roundtrip(self, device, dtype):
    """istft(stft(x)) must reproduce x, via both the functional and the
    tensor-method interfaces."""
    test_args = list(product(
        # input
        (torch.randn(600, device=device, dtype=dtype),
         torch.randn(807, device=device, dtype=dtype),
         torch.randn(12, 60, device=device, dtype=dtype)),
        # NOTE(review): the n_fft / hop_length / center / normalized factors of
        # this product are on lines missing from this chunk.
        # pad_mode
        ("constant", "reflect", "circular"),
        # onesided only applies to real input
        (True, False) if not dtype.is_complex else (False,),

    for args in test_args:
        x, n_fft, hop_length, center, pad_mode, normalized, onesided = args
        # Keyword arguments shared between stft and istft.
        # NOTE(review): the ``common_kwargs = {`` opener and closing ``}`` are
        # on lines missing from this chunk.
        'n_fft': n_fft, 'hop_length': hop_length, 'center': center,
        'normalized': normalized, 'onesided': onesided,

        # Functional interface.
        x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs)
        x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
                                  length=x.size(-1), **common_kwargs)
        self.assertEqual(x_roundtrip, x)

        # Tensor-method interface.
        x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs)
        x_roundtrip = torch.istft(x_stft, return_complex=dtype.is_complex,
                                  length=x.size(-1), **common_kwargs)
        self.assertEqual(x_roundtrip, x)

@onlyNativeDeviceTypes
@dtypes(torch.double, torch.cdouble)
def test_stft_roundtrip_complex_window(self, device, dtype):
    """stft/istft must round-trip when the window itself is complex."""
    test_args = list(product(
        # input
        (torch.randn(600, device=device, dtype=dtype),
         torch.randn(807, device=device, dtype=dtype),
         torch.randn(12, 60, device=device, dtype=dtype)),
        # NOTE(review): the n_fft / hop_length / normalized factors of this
        # product are on lines missing from this chunk.
        # pad_mode
        ("constant", "reflect", "replicate", "circular"),

    for args in test_args:
        x, n_fft, hop_length, pad_mode, normalized = args
        window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
        x_stft = torch.stft(
            x, n_fft=n_fft, hop_length=hop_length, window=window,
            center=True, pad_mode=pad_mode, normalized=normalized)
        # With a complex window the output is cdouble and keeps all n_fft bins.
        self.assertEqual(x_stft.dtype, torch.cdouble)
        self.assertEqual(x_stft.size(-2), n_fft)

        x_roundtrip = torch.istft(
            x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
            center=True, normalized=normalized, length=x.size(-1),
            return_complex=True)
        self.assertEqual(x_stft.dtype, torch.cdouble)

        if not dtype.is_complex:
            # Real input: the reconstruction's imaginary part must vanish.
            # NOTE(review): the atol/rtol arguments closing this call are on a
            # line missing from this chunk.
            self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
            self.assertEqual(x_roundtrip.real, x)
        # NOTE(review): the ``else:`` header is on a line missing from this chunk.
            self.assertEqual(x_roundtrip, x)
1122
@dtypes(torch.cdouble)
def test_complex_stft_definition(self, device, dtype):
    """torch.stft on complex input must match the textbook _stft_reference."""
    test_args = list(product(
        # input
        (torch.randn(600, device=device, dtype=dtype),
         torch.randn(807, device=device, dtype=dtype)),
        # NOTE(review): the n_fft / hop_length factors of this product are on
        # lines missing from this chunk.

    for args in test_args:
        # args = (input, n_fft, hop_length) per the indexing below.
        window = torch.randn(args[1], device=device, dtype=dtype)
        expected = _stft_reference(args[0], args[2], window)
        actual = torch.stft(*args, window=window, center=False)
        self.assertEqual(actual, expected)

@onlyNativeDeviceTypes
@dtypes(torch.cdouble)
def test_complex_stft_real_equiv(self, device, dtype):
    """Complex stft must equal the two-real-stft construction (_complex_stft)."""
    test_args = list(product(
        # input
        (torch.rand(600, device=device, dtype=dtype),
         torch.rand(807, device=device, dtype=dtype),
         torch.rand(14, 50, device=device, dtype=dtype),
         torch.rand(6, 51, device=device, dtype=dtype)),
        # NOTE(review): the n_fft / hop_length / win_length / center /
        # normalized factors of this product are on lines missing from this chunk.
        # pad_mode
        ("constant", "reflect", "circular"),

    for args in test_args:
        x, n_fft, hop_length, win_length, center, pad_mode, normalized = args
        expected = _complex_stft(x, n_fft, hop_length=hop_length,
                                 win_length=win_length, pad_mode=pad_mode,
                                 center=center, normalized=normalized)
        actual = torch.stft(x, n_fft, hop_length=hop_length,
                            win_length=win_length, pad_mode=pad_mode,
                            center=center, normalized=normalized)
        self.assertEqual(expected, actual)

@dtypes(torch.cdouble)
def test_complex_istft_real_equiv(self, device, dtype):
    """Complex istft must equal the two-real-istft construction
    (_complex_istft)."""
    test_args = list(product(
        # input spectrograms (freq x frames, optionally batched)
        (torch.rand(40, 20, device=device, dtype=dtype),
         torch.rand(25, 1, device=device, dtype=dtype),
         torch.rand(4, 20, 10, device=device, dtype=dtype)),
        # NOTE(review): the hop_length / center / normalized factors of this
        # product are on lines missing from this chunk.

    for args in test_args:
        x, hop_length, center, normalized = args
        # NOTE(review): the derivation of ``n_fft`` (presumably from
        # x.size(-2)) is on a line missing from this chunk — TODO confirm.
        expected = _complex_istft(x, n_fft, hop_length=hop_length,
                                  center=center, normalized=normalized)
        actual = torch.istft(x, n_fft, hop_length=hop_length,
                             center=center, normalized=normalized,
                             return_complex=True)
        self.assertEqual(expected, actual)
1201
def test_complex_stft_onesided(self, device):
1203
for x_dtype, window_dtype in product((torch.double, torch.cdouble), repeat=2):
1204
x = torch.rand(100, device=device, dtype=x_dtype)
1205
window = torch.rand(10, device=device, dtype=window_dtype)
1207
if x_dtype.is_complex or window_dtype.is_complex:
1208
with self.assertRaisesRegex(RuntimeError, 'complex'):
1209
x.stft(10, window=window, pad_mode='constant', onesided=True)
1211
y = x.stft(10, window=window, pad_mode='constant', onesided=True,
1212
return_complex=True)
1213
self.assertEqual(y.dtype, torch.cdouble)
1214
self.assertEqual(y.size(), (6, 51))
1216
x = torch.rand(100, device=device, dtype=torch.cdouble)
1217
with self.assertRaisesRegex(RuntimeError, 'complex'):
1218
x.stft(10, pad_mode='constant', onesided=True)
1221
@onlyNativeDeviceTypes
1223
def test_stft_requires_complex(self, device):
1225
with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
1226
y = x.stft(10, pad_mode='constant')
1229
@onlyNativeDeviceTypes
1231
def test_stft_requires_window(self, device):
1233
with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
1234
y = x.stft(10, pad_mode='constant', return_complex=True)
1236
@onlyNativeDeviceTypes
1238
def test_istft_requires_window(self, device):
    """istft without an explicit window must warn that a default is used."""
    # 51 frequency bins x 5 frames; istft falls back to a rectangular
    # (all-ones) window and should tell the user about it.
    stft = torch.rand((51, 5), dtype=torch.cdouble)
    with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
        x = torch.istft(stft, n_fft=100, length=100)
def test_fft_input_modification(self, device):
1248
signal = torch.ones((2, 2, 2), device=device)
1249
signal_copy = signal.clone()
1250
spectrum = torch.fft.fftn(signal, dim=(-2, -1))
1251
self.assertEqual(signal, signal_copy)
1253
spectrum_copy = spectrum.clone()
1254
_ = torch.fft.ifftn(spectrum, dim=(-2, -1))
1255
self.assertEqual(spectrum, spectrum_copy)
1257
half_spectrum = torch.fft.rfftn(signal, dim=(-2, -1))
1258
self.assertEqual(signal, signal_copy)
1260
half_spectrum_copy = half_spectrum.clone()
1261
_ = torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=(-2, -1))
1262
self.assertEqual(half_spectrum, half_spectrum_copy)
1264
@onlyNativeDeviceTypes
1266
def test_fft_plan_repeatable(self, device):
1268
for n in [2048, 3199, 5999]:
1269
a = torch.randn(n, device=device, dtype=torch.complex64)
1270
res1 = torch.fft.fftn(a)
1271
res2 = torch.fft.fftn(a.clone())
1272
self.assertEqual(res1, res2)
1274
a = torch.randn(n, device=device, dtype=torch.float64)
1275
res1 = torch.fft.rfft(a)
1276
res2 = torch.fft.rfft(a.clone())
1277
self.assertEqual(res1, res2)
1279
@onlyNativeDeviceTypes
1281
@dtypes(torch.double)
1282
def test_istft_round_trip_simple_cases(self, device, dtype):
1283
"""stft -> istft should recover the original signale"""
1284
def _test(input, n_fft, length):
1285
stft = torch.stft(input, n_fft=n_fft, return_complex=True)
1286
inverse = torch.istft(stft, n_fft=n_fft, length=length)
1287
self.assertEqual(input, inverse, exact_dtype=True)
1289
_test(torch.ones(4, dtype=dtype, device=device), 4, 4)
1290
_test(torch.zeros(4, dtype=dtype, device=device), 4, 4)
1292
@onlyNativeDeviceTypes
1294
@dtypes(torch.double)
1295
def test_istft_round_trip_various_params(self, device, dtype):
1296
"""stft -> istft should recover the original signale"""
1297
def _test_istft_is_inverse_of_stft(stft_kwargs):
1300
data_sizes = [(2, 20), (3, 15), (4, 10)]
1302
istft_kwargs = stft_kwargs.copy()
1303
del istft_kwargs['pad_mode']
1304
for sizes in data_sizes:
1305
for i in range(num_trials):
1306
original = torch.randn(*sizes, dtype=dtype, device=device)
1307
stft = torch.stft(original, return_complex=True, **stft_kwargs)
1308
inversed = torch.istft(stft, length=original.size(1), **istft_kwargs)
1310
inversed, original, msg='istft comparison against original',
1311
atol=7e-6, rtol=0, exact_dtype=True)
1319
'window': torch.hann_window(12, dtype=dtype, device=device),
1321
'pad_mode': 'reflect',
1330
'window': torch.hann_window(8, dtype=dtype, device=device),
1332
'pad_mode': 'reflect',
1333
'normalized': False,
1341
'window': torch.hamming_window(11, dtype=dtype, device=device),
1343
'pad_mode': 'constant',
1353
'window': torch.hamming_window(5, dtype=dtype, device=device),
1355
'pad_mode': 'constant',
1356
'normalized': False,
1360
for i, pattern in enumerate(patterns):
1361
_test_istft_is_inverse_of_stft(pattern)
1363
@onlyNativeDeviceTypes
1365
@dtypes(torch.double)
1366
def test_istft_round_trip_with_padding(self, device, dtype):
1367
"""long hop_length or not centered may cause length mismatch in the inversed signal"""
1368
def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs):
1372
sizes = stft_kwargs['size']
1373
del stft_kwargs['size']
1374
istft_kwargs = stft_kwargs.copy()
1375
del istft_kwargs['pad_mode']
1376
for i in range(num_trials):
1377
original = torch.randn(*sizes, dtype=dtype, device=device)
1378
stft = torch.stft(original, return_complex=True, **stft_kwargs)
1379
with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
1380
inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs)
1381
n_frames = stft.size(-1)
1382
if stft_kwargs["center"] is True:
1383
len_expected = stft_kwargs["n_fft"] // 2 + stft_kwargs["hop_length"] * (n_frames - 1)
1385
len_expected = stft_kwargs["n_fft"] + stft_kwargs["hop_length"] * (n_frames - 1)
1387
padding = inversed[..., len_expected:]
1388
inversed = inversed[..., :len_expected]
1389
original = original[..., :len_expected]
1391
zeros = torch.zeros_like(padding, device=padding.device)
1393
padding, zeros, msg='istft padding values against zeros',
1394
atol=7e-6, rtol=0, exact_dtype=True)
1396
inversed, original, msg='istft comparison against original',
1397
atol=7e-6, rtol=0, exact_dtype=True)
1407
'window': torch.hamming_window(3, dtype=dtype, device=device),
1409
'pad_mode': 'reflect',
1410
'normalized': False,
1420
'window': torch.hamming_window(256, dtype=dtype, device=device),
1422
'pad_mode': 'constant',
1423
'normalized': False,
1427
for i, pattern in enumerate(patterns):
1428
_test_istft_is_inverse_of_stft_with_padding(pattern)
1430
@onlyNativeDeviceTypes
1431
def test_istft_throws(self, device):
1432
"""istft should throw exception for invalid parameters"""
1433
stft = torch.zeros((3, 5, 2), device=device)
1436
RuntimeError, torch.istft, stft, n_fft=4,
1437
hop_length=20, win_length=1, window=torch.ones(1))
1439
invalid_window = torch.zeros(4, device=device)
1441
RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
1443
self.assertRaises(RuntimeError, torch.istft, torch.zeros((3, 0, 2)), 2)
1444
self.assertRaises(RuntimeError, torch.istft, torch.zeros((0, 3, 2)), 2)
1446
@skipIfTorchDynamo("Failed running call_function")
1447
@onlyNativeDeviceTypes
1449
@dtypes(torch.double)
1450
def test_istft_of_sine(self, device, dtype):
1451
complex_dtype = corresponding_complex_dtype(dtype)
1453
def _test(amplitude, L, n):
1455
x = torch.arange(2 * L + 1, device=device, dtype=dtype)
1456
original = amplitude * torch.sin(2 * math.pi / L * x * n)
1459
stft = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
1460
stft_largest_val = (amplitude * L) / 2.0
1461
if n < stft.size(0):
1462
stft[n].imag = torch.tensor(-stft_largest_val, dtype=dtype)
1464
if 0 <= L - n < stft.size(0):
1466
stft[L - n].imag = torch.tensor(stft_largest_val, dtype=dtype)
1468
inverse = torch.istft(
1469
stft, L, hop_length=L, win_length=L,
1470
window=torch.ones(L, device=device, dtype=dtype), center=False, normalized=False)
1472
original = original[..., :inverse.size(-1)]
1473
self.assertEqual(inverse, original, atol=1e-3, rtol=0)
1475
_test(amplitude=123, L=5, n=1)
1476
_test(amplitude=150, L=5, n=2)
1477
_test(amplitude=111, L=5, n=3)
1478
_test(amplitude=160, L=7, n=4)
1479
_test(amplitude=145, L=8, n=5)
1480
_test(amplitude=80, L=9, n=6)
1481
_test(amplitude=99, L=10, n=7)
1483
@onlyNativeDeviceTypes
1485
@dtypes(torch.double)
1486
def test_istft_linearity(self, device, dtype):
1488
complex_dtype = corresponding_complex_dtype(dtype)
1490
def _test(data_size, kwargs):
1491
for i in range(num_trials):
1492
tensor1 = torch.randn(data_size, device=device, dtype=complex_dtype)
1493
tensor2 = torch.randn(data_size, device=device, dtype=complex_dtype)
1494
a, b = torch.rand(2, dtype=dtype, device=device)
1496
istft1 = tensor1.istft(**kwargs)
1497
istft2 = tensor2.istft(**kwargs)
1498
istft = a * istft1 + b * istft2
1499
estimate = torch.istft(a * tensor1 + b * tensor2, **kwargs)
1500
self.assertEqual(istft, estimate, atol=1e-5, rtol=0)
1507
'window': torch.hann_window(12, device=device, dtype=dtype),
1518
'window': torch.hann_window(12, device=device, dtype=dtype),
1520
'normalized': False,
1529
'window': torch.hamming_window(12, device=device, dtype=dtype),
1540
'window': torch.hamming_window(12, device=device, dtype=dtype),
1542
'normalized': False,
1547
for data_size, kwargs in patterns:
1548
_test(data_size, kwargs)
1550
@onlyNativeDeviceTypes
1552
def test_batch_istft(self, device):
1553
original = torch.tensor([
1554
[4., 4., 4., 4., 4.],
1555
[0., 0., 0., 0., 0.],
1556
[0., 0., 0., 0., 0.]
1557
], device=device, dtype=torch.complex64)
1559
single = original.repeat(1, 1, 1)
1560
multi = original.repeat(4, 1, 1)
1562
i_original = torch.istft(original, n_fft=4, length=4)
1563
i_single = torch.istft(single, n_fft=4, length=4)
1564
i_multi = torch.istft(multi, n_fft=4, length=4)
1566
self.assertEqual(i_original.repeat(1, 1), i_single, atol=1e-6, rtol=0, exact_dtype=True)
1567
self.assertEqual(i_original.repeat(4, 1), i_multi, atol=1e-6, rtol=0, exact_dtype=True)
1570
@skipIf(not TEST_MKL, "Test requires MKL")
1571
def test_stft_window_device(self, device):
    """stft/istft must reject a window living on a different device than the input."""
    x = torch.randn(1000, dtype=torch.complex64)
    window = torch.randn(100, dtype=torch.complex64)

    with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
        torch.stft(x, n_fft=100, window=window.to(device))

    with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
        torch.stft(x.to(device), n_fft=100, window=window)

    X = torch.stft(x, n_fft=100, window=window)

    with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
        torch.istft(X, n_fft=100, window=window.to(device))

    with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
        # NOTE(review): this passes the raw signal x rather than the
        # spectrogram X; the device-mismatch check appears to fire before
        # any shape validation, so the intended error path is still
        # exercised -- confirm this is deliberate.
        torch.istft(x.to(device), n_fft=100, window=window)
class FFTDocTestFinder:
1592
'''The default doctest finder doesn't like that function.__module__ doesn't
1593
match torch.fft. It assumes the functions are leaked imports.
1596
self.parser = doctest.DocTestParser()
1598
def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
1601
modname = name if name is not None else obj.__name__
1602
globs = {} if globs is None else globs
1604
for fname in obj.__all__:
1605
func = getattr(obj, fname)
1606
if inspect.isroutine(func):
1607
qualname = modname + '.' + fname
1608
docstring = inspect.getdoc(func)
1609
if docstring is None:
1612
examples = self.parser.get_doctest(
1613
docstring, globs=globs, name=fname, filename=None, lineno=None)
1614
doctests.append(examples)
1619
class TestFFTDocExamples(TestCase):
1622
def generate_doc_test(doc_test):
1623
def test(self, device):
1624
self.assertEqual(device, 'cpu')
1625
runner = doctest.DocTestRunner()
1626
runner.run(doc_test)
1628
if runner.failures != 0:
1630
self.fail('Doctest failed')
1632
setattr(TestFFTDocExamples, 'test_' + doc_test.name, skipCPUIfNoFFT(test))
1634
for doc_test in FFTDocTestFinder().find(torch.fft, globs=dict(torch=torch)):
1635
generate_doc_test(doc_test)
1638
instantiate_device_type_tests(TestFFT, globals())
1639
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')
1641
if __name__ == '__main__':