pytorch

Форк
0
/
test_spectral_ops.py 
1642 строки · 65.0 Кб
1
# Owner(s): ["module: fft"]
2

3
import torch
4
import unittest
5
import math
6
from contextlib import contextmanager
7
from itertools import product
8
import itertools
9
import doctest
10
import inspect
11

12
from torch.testing._internal.common_utils import \
13
    (TestCase, run_tests, TEST_NUMPY, TEST_LIBROSA, TEST_MKL, first_sample, TEST_WITH_ROCM,
14
     make_tensor, skipIfTorchDynamo)
15
from torch.testing._internal.common_device_type import \
16
    (instantiate_device_type_tests, ops, dtypes, onlyNativeDeviceTypes,
17
     skipCPUIfNoFFT, deviceCountAtLeast, onlyCUDA, OpDTypes, skipIf, toleranceOverride, tol)
18
from torch.testing._internal.common_methods_invocations import (
19
    spectral_funcs, SpectralFuncType)
20
from torch.testing._internal.common_cuda import SM53OrLater
21
from torch._prims_common import corresponding_complex_dtype
22

23
from typing import Optional, List
24
from packaging import version
25

26

27
if TEST_NUMPY:
28
    import numpy as np
29

30

31
if TEST_LIBROSA:
32
    import librosa
33

34
has_scipy_fft = False
35
try:
36
    import scipy.fft
37
    has_scipy_fft = True
38
except ModuleNotFoundError:
39
    pass
40

41
# Normalization modes passed to the reference (NumPy/SciPy) implementations.
# All four modes are used only when NumPy is importable with version >= 1.20.0
# and SciPy's fft module is either unavailable or SciPy is >= 1.6.0; otherwise
# only the default (None) and "ortho" are exercised.
# TEST_NUMPY is checked first because ``np`` is only bound when NumPy could be
# imported; without the guard this module would fail with a NameError.
REFERENCE_NORM_MODES = (
    (None, "forward", "backward", "ortho")
    if TEST_NUMPY
    and version.parse(np.__version__) >= version.parse('1.20.0')
    and (not has_scipy_fft
         or version.parse(scipy.__version__) >= version.parse('1.6.0'))
    else (None, "ortho"))
46

47

48
def _complex_stft(x, *args, **kwargs):
49
    # Transform real and imaginary components separably
50
    stft_real = torch.stft(x.real, *args, **kwargs, return_complex=True, onesided=False)
51
    stft_imag = torch.stft(x.imag, *args, **kwargs, return_complex=True, onesided=False)
52
    return stft_real + 1j * stft_imag
53

54

55
def _hermitian_conj(x, dim):
56
    """Returns the hermitian conjugate along a single dimension
57

58
    H(x)[i] = conj(x[-i])
59
    """
60
    out = torch.empty_like(x)
61
    mid = (x.size(dim) - 1) // 2
62
    idx = [slice(None)] * out.dim()
63
    idx_center = list(idx)
64
    idx_center[dim] = 0
65
    out[idx] = x[idx]
66

67
    idx_neg = list(idx)
68
    idx_neg[dim] = slice(-mid, None)
69
    idx_pos = idx
70
    idx_pos[dim] = slice(1, mid + 1)
71

72
    out[idx_pos] = x[idx_neg].flip(dim)
73
    out[idx_neg] = x[idx_pos].flip(dim)
74
    if (2 * mid + 1 < x.size(dim)):
75
        idx[dim] = mid + 1
76
        out[idx] = x[idx]
77
    return out.conj()
78

79

80
def _complex_istft(x, *args, **kwargs):
    """Invert a two-sided complex STFT using two one-sided ``torch.istft`` calls.

    ``x`` is split into its Hermitian part (the FFT of the real component)
    and its anti-Hermitian part (the FFT of the imaginary component) along
    dimension -2.  Each part is truncated to its first ``n_fft // 2 + 1``
    entries along that dimension and inverted with ``onesided=True``; the two
    real results become the real and imaginary parts of the returned tensor.
    Extra positional and keyword arguments are forwarded to ``torch.istft``.
    """
    num_bins = x.size(-2)
    # Selector for the leading half of dimension -2 expected by onesided istft.
    onesided_bins = (Ellipsis, slice(None, num_bins // 2 + 1), slice(None))

    reflected = _hermitian_conj(x, dim=-2)
    hermitian_part = (x + reflected) / 2
    antihermitian_part = (x - reflected) / 2

    real_signal = torch.istft(
        hermitian_part[onesided_bins], *args, **kwargs, onesided=True)
    imag_signal = torch.istft(
        -1j * antihermitian_part[onesided_bins], *args, **kwargs, onesided=True)
    return torch.complex(real_signal, imag_signal)
91

92

93
def _stft_reference(x, hop_length, window):
94
    r"""Reference stft implementation
95

96
    This doesn't implement all of torch.stft, only the STFT definition:
97

98
    .. math:: X(m, \omega) = \sum_n x[n]w[n - m] e^{-jn\omega}
99

100
    """
101
    n_fft = window.numel()
102
    X = torch.empty((n_fft, (x.numel() - n_fft + hop_length) // hop_length),
103
                    device=x.device, dtype=torch.cdouble)
104
    for m in range(X.size(1)):
105
        start = m * hop_length
106
        if start + n_fft > x.numel():
107
            slc = torch.empty(n_fft, device=x.device, dtype=x.dtype)
108
            tmp = x[start:]
109
            slc[:tmp.numel()] = tmp
110
        else:
111
            slc = x[start: start + n_fft]
112
        X[:, m] = torch.fft.fft(slc * window)
113
    return X
114

115

116
def skip_helper_for_fft(device, dtype):
    """Raise ``unittest.SkipTest`` for half-precision dtypes the device can't run.

    Dtypes other than ``torch.half`` / ``torch.complex32`` never skip.  For
    those two, the test is skipped on CPU and whenever ``SM53OrLater`` is
    false.
    """
    device_type = torch.device(device).type
    needs_half_support = dtype in (torch.half, torch.complex32)
    if needs_half_support:
        if device_type == 'cpu':
            raise unittest.SkipTest("half and complex32 are not supported on CPU")
        if not SM53OrLater:
            raise unittest.SkipTest("half and complex32 are only supported on CUDA device with SM>53")
125

126

127
# Tests of functions related to Fourier analysis in the torch.fft namespace
128
class TestFFT(TestCase):
129
    exact_dtype = True
130

131
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.OneD],
         allowed_dtypes=(torch.float, torch.cfloat))
    def test_reference_1d(self, device, dtype, op):
        """Compare each 1-D spectral op with its reference ``op.ref``.

        Random inputs of several shapes are transformed with every combination
        of ``n``, ``dim`` and ``norm`` listed below; the op's output must match
        what ``op.ref`` returns for the NumPy copy of the same input.  Ops
        without a reference implementation are skipped.
        """
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        norm_modes = REFERENCE_NORM_MODES
        test_args = [
            *product(
                # input
                (torch.randn(67, device=device, dtype=dtype),
                 torch.randn(80, device=device, dtype=dtype),
                 torch.randn(12, 14, device=device, dtype=dtype),
                 torch.randn(9, 6, 3, device=device, dtype=dtype)),
                # n
                (None, 50, 6),
                # dim
                (-1, 0),
                # norm
                norm_modes
            ),
            # Test transforming middle dimensions of multi-dim tensor
            *product(
                (torch.randn(4, 5, 6, 7, device=device, dtype=dtype),),
                (None,),
                (1, 2, -2,),
                norm_modes
            )
        ]

        # Matching result dtypes are only required for double-precision inputs;
        # this does not depend on the individual test case.
        exact_dtype = dtype in (torch.double, torch.complex128)

        for signal, *args in test_args:
            expected = op.ref(signal.cpu().numpy(), *args)
            actual = op(signal, *args)
            self.assertEqual(actual, expected, exact_dtype=exact_dtype)
171

172
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double, torch.complex32, torch.complex64, torch.complex128)
    def test_fft_round_trip(self, device, dtype):
        """Check that each inverse transform undoes its forward transform.

        ``fft``/``ifft`` is always tested; for real dtypes ``rfft``/``irfft``
        and ``ihfft``/``hfft`` are tested as well.  Every pair is run over
        several input shapes, both ``dim`` values and all ``norm`` modes.
        """
        skip_helper_for_fft(device, dtype)

        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            shapes = ((64,), (128,), (4, 16), (8, 6, 2))
        else:
            shapes = ((67,), (80,), (12, 14), (9, 6, 3))

        signals = tuple(
            torch.randn(*shape, device=device, dtype=dtype) for shape in shapes)
        test_args = list(product(
            signals,
            # dim
            (-1, 0),
            # norm
            (None, "forward", "backward", "ortho"),
        ))

        transform_pairs = [(torch.fft.fft, torch.fft.ifft)]
        if not dtype.is_complex:
            # Real-only functions.
            # NOTE: Using ihfft as "forward" transform to avoid needing to
            # generate true half-complex input
            transform_pairs.extend([
                (torch.fft.rfft, torch.fft.irfft),
                (torch.fft.ihfft, torch.fft.hfft),
            ])

        for forward, backward in transform_pairs:
            for x, dim, norm in test_args:
                kwargs = dict(n=x.size(dim), dim=dim, norm=norm)
                y = backward(forward(x, **kwargs), **kwargs)

                if x.dtype is torch.half and y.dtype is torch.complex32:
                    # Since type promotion currently doesn't work with complex32
                    # manually promote `x` to complex32
                    x = x.to(torch.complex32)

                # For real input, ifft(fft(x)) will convert to complex
                dtype_is_preserved = forward != torch.fft.fft or x.is_complex()
                self.assertEqual(x, y, exact_dtype=dtype_is_preserved)
232

233
    # Note: NumPy will throw a ValueError for an empty input
234
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.float, torch.complex32, torch.cfloat))
    def test_empty_fft(self, device, dtype, op):
        """Every spectral op must raise a RuntimeError for a ``(1, 0)`` input."""
        empty_signal = torch.empty(1, 0, device=device, dtype=dtype)
        expected_error = r"Invalid number of data points \([-\d]*\) specified"
        self.assertRaisesRegex(RuntimeError, expected_error, op, empty_signal)
242

243
    @onlyNativeDeviceTypes
    def test_empty_ifft(self, device):
        """The irfft*/hfft* functions must reject a ``(2, 1)`` complex input.

        The functions are called with default arguments only and must raise
        the "Invalid number of data points" RuntimeError.
        """
        spectrum = torch.empty(2, 1, device=device, dtype=torch.complex64)
        expected_error = r"Invalid number of data points \([-\d]*\) specified"

        transforms = (
            torch.fft.irfft, torch.fft.irfft2, torch.fft.irfftn,
            torch.fft.hfft, torch.fft.hfft2, torch.fft.hfftn,
        )
        for transform in transforms:
            self.assertRaisesRegex(RuntimeError, expected_error, transform, spectrum)
252

253
    @onlyNativeDeviceTypes
    def test_fft_invalid_dtypes(self, device):
        """rfft, rfftn and ihfft must reject a complex input tensor."""
        complex_signal = torch.randn(64, device=device, dtype=torch.complex128)

        # (transform, message expected in the RuntimeError)
        rejections = (
            (torch.fft.rfft, "rfft expects a real input tensor"),
            (torch.fft.rfftn, "rfftn expects a real-valued input tensor"),
            (torch.fft.ihfft, "ihfft expects a real input tensor"),
        )
        for transform, message in rejections:
            with self.assertRaisesRegex(RuntimeError, message):
                transform(complex_signal)
265

266
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fft_type_promotion(self, device, dtype):
        """Check the result dtypes of ``fft``, ``hfft`` and ``rfft``.

        Each transform is applied to an input of ``dtype`` and its output
        dtype is compared with the table for that kind of transform.
        ``rfft`` is only checked for non-complex inputs.
        """
        skip_helper_for_fft(device, dtype)

        if dtype.is_complex or dtype.is_floating_point:
            t = torch.randn(64, device=device, dtype=dtype)
        else:
            t = torch.randint(-2, 2, (64,), device=device, dtype=dtype)

        # fft: result is always complex
        fft_result_dtype = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
            torch.complex32: torch.complex32,
            torch.complex64: torch.complex64,
            torch.complex128: torch.complex128,
        }
        self.assertEqual(torch.fft.fft(t).dtype, fft_result_dtype[dtype])

        # hfft: result is always real
        hfft_result_dtype = {
            torch.int8: torch.float,
            torch.half: torch.half,
            torch.float: torch.float,
            torch.double: torch.double,
            torch.complex32: torch.half,
            torch.complex64: torch.float,
            torch.complex128: torch.double,
        }
        if dtype in (torch.half, torch.complex32):
            # cuFFT supports powers of 2 for half and complex half precision
            # NOTE: With hfft and default args where output_size n=2*(input_size - 1),
            # we make sure that logical fft size is a power of two.
            hfft_input = torch.randn(65, device=device, dtype=dtype)
        else:
            hfft_input = t
        self.assertEqual(torch.fft.hfft(hfft_input).dtype, hfft_result_dtype[dtype])

        if dtype.is_complex:
            return

        # rfft: real input only, complex result
        rfft_result_dtype = {
            torch.int8: torch.complex64,
            torch.half: torch.complex32,
            torch.float: torch.complex64,
            torch.double: torch.complex128,
        }
        self.assertEqual(torch.fft.rfft(t).dtype, rfft_result_dtype[dtype])
318

319
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, dtypes=OpDTypes.unsupported,
         allowed_dtypes=[torch.half, torch.bfloat16])
    def test_fft_half_and_bfloat16_errors(self, device, dtype, op):
        """Ops must raise a RuntimeError for half/bfloat16 dtypes they don't support.

        The expected message is "Unsupported dtype", except for ``torch.half``
        on a CUDA device that is not ROCm and for which ``SM53OrLater`` is
        false, where the cuFFT compute-capability message is expected.
        """
        # TODO: Remove torch.half error when complex32 is fully implemented
        sample = first_sample(self, op.sample_inputs(device, dtype))
        device_type = torch.device(device).type
        half_on_cuda = dtype is torch.half and device_type == 'cuda'

        if half_on_cuda and not TEST_WITH_ROCM and not SM53OrLater:
            err_msg = "cuFFT doesn't support signals of half type with compute capability less than SM_53"
        else:
            err_msg = "Unsupported dtype"

        with self.assertRaisesRegex(RuntimeError, err_msg):
            op(sample.input, *sample.args, **sample.kwargs)
335

336
    @onlyNativeDeviceTypes
    @ops(spectral_funcs, allowed_dtypes=(torch.half, torch.chalf))
    def test_fft_half_and_chalf_not_power_of_two_error(self, device, dtype, op):
        """Half-precision transforms must reject sizes that are not powers of two.

        The op is called on a 13x13 tensor twice: once with default sizes and
        once with an explicit size of 12 (``s=(12, 12)`` for 2-D/N-D ops,
        ``n=12`` otherwise).  Both calls must raise the cuFFT size error.
        """
        t = make_tensor(13, 13, device=device, dtype=dtype)
        err_msg = "cuFFT only supports dimensions whose sizes are powers of two"

        takes_shape = op.ndimensional in (SpectralFuncType.ND, SpectralFuncType.TwoD)
        explicit_size = {'s': (12, 12)} if takes_shape else {'n': 12}

        for kwargs in ({}, explicit_size):
            with self.assertRaisesRegex(RuntimeError, err_msg):
                op(t, **kwargs)
351

352
    # nd-fft tests
353
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=(torch.cfloat, torch.cdouble))
    def test_reference_nd(self, device, dtype, op):
        """Compare each N-D spectral op with its reference ``op.ref``.

        For every (number of input dims, ``s``, ``dim``) description a random
        input is created and transformed with each reference norm mode; the
        result must match ``op.ref`` applied to the NumPy copy of the input.
        """
        if op.ref is None:
            raise unittest.SkipTest("No reference implementation")

        # input_ndim, s, dim
        transform_desc = list(product(range(2, 5), (None,), (None, (0,), (0, -1))))
        transform_desc += product(range(2, 5), (None, (4, 10)), (None,))
        transform_desc += [
            (6, None, None),
            (5, None, (1, 3, 4)),
            (3, None, (1,)),
            (1, None, (0,)),
            (4, (10, 10), None),
            (4, (10, 10), (0, 1)),
        ]

        # Matching result dtypes are only required for double precision.
        require_same_dtype = dtype in (torch.double, torch.complex128)

        for input_ndim, s, dim in transform_desc:
            # Sizes cycle through 4, 5, 6, 7, 8 for as many dims as requested.
            sizes = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            signal = torch.randn(*sizes, device=device, dtype=dtype)

            for norm in REFERENCE_NORM_MODES:
                expected = op.ref(signal.cpu().numpy(), s, dim, norm)
                actual = op(signal, s, dim, norm)
                self.assertEqual(actual, expected, exact_dtype=require_same_dtype)
384

385
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
        torch.chalf : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double,
            torch.complex32, torch.complex64, torch.complex128)
    def test_fftn_round_trip(self, device, dtype):
        """Check that each N-D inverse transform undoes its forward transform.

        ``fftn``/``ifftn`` is always tested; for real dtypes ``rfftn``/``irfftn``
        and ``ihfftn``/``hfftn`` are tested as well, for every norm mode and
        every (number of input dims, ``dim``) description below.
        """
        skip_helper_for_fft(device, dtype)

        norm_modes = (None, "forward", "backward", "ortho")

        # input_ndim, dim
        transform_desc = list(product(range(2, 5), (None, (0,), (0, -1))))
        transform_desc += [
            (7, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, 0),
        ]

        transform_pairs = [(torch.fft.fftn, torch.fft.ifftn)]
        if not dtype.is_complex:
            # Real-only functions.
            # NOTE: Using ihfftn as "forward" transform to avoid needing to
            # generate true half-complex input
            transform_pairs.extend([
                (torch.fft.rfftn, torch.fft.irfftn),
                (torch.fft.ihfftn, torch.fft.hfftn),
            ])

        is_half_precision = dtype in (torch.half, torch.complex32)
        # cuFFT supports powers of 2 for half and complex half precision
        size_cycle = (2, 4, 8) if is_half_precision else range(4, 9)

        for input_ndim, dim in transform_desc:
            shape = itertools.islice(itertools.cycle(size_cycle), input_ndim)
            x = torch.randn(*shape, device=device, dtype=dtype)

            # Signal size along the transformed dims, in the same form as `dim`.
            if isinstance(dim, tuple):
                s = [x.size(d) for d in dim]
            elif dim is None:
                s = x.size()
            else:
                s = x.size(dim)

            for (forward, backward), norm in product(transform_pairs, norm_modes):
                kwargs = dict(s=s, dim=dim, norm=norm)
                y = backward(forward(x, **kwargs), **kwargs)

                if x.dtype is torch.half and y.dtype is torch.chalf:
                    # Since type promotion currently doesn't work with complex32
                    # manually promote `x` to complex32
                    self.assertEqual(x.to(torch.chalf), y)
                    continue

                # For real input, ifftn(fftn(x)) will convert to complex
                dtype_is_preserved = forward != torch.fft.fftn or x.is_complex()
                self.assertEqual(x, y, exact_dtype=dtype_is_preserved)
440

441
    @onlyNativeDeviceTypes
    @ops([op for op in spectral_funcs if op.ndimensional == SpectralFuncType.ND],
         allowed_dtypes=[torch.float, torch.cfloat])
    def test_fftn_invalid(self, device, dtype, op):
        """N-D ops must raise the expected error for invalid ``s``/``dim`` arguments."""
        a = torch.rand(10, 10, 10, device=device, dtype=dtype)
        # FIXME: https://github.com/pytorch/pytorch/issues/108205
        unique_dims_msg = "dims must be unique"

        # (exception type, message pattern, keyword arguments for the op)
        invalid_calls = (
            (RuntimeError, unique_dims_msg, dict(dim=(0, 1, 0))),
            (RuntimeError, unique_dims_msg, dict(dim=(2, -1))),
            (RuntimeError, "dim and shape .* same length", dict(s=(1,), dim=(0, 1))),
            (IndexError, "Dimension out of range", dict(dim=(3,))),
            (RuntimeError, "tensor only has 3 dimensions", dict(s=(10, 10, 10, 10))),
        )
        for error_type, pattern, kwargs in invalid_calls:
            with self.assertRaisesRegex(error_type, pattern):
                op(a, **kwargs)
462

463
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.half, torch.float, torch.double, torch.cfloat, torch.cdouble)
    def test_fftn_noop_transform(self, device, dtype):
        """Transforming over an empty ``dim`` list must leave the values unchanged.

        For each of ``fftn``, ``ifftn``, ``fft2`` and ``ifft2`` the output for
        ``dim=[]`` must equal the input converted to the matching complex
        dtype (complex inputs keep their dtype).
        """
        skip_helper_for_fft(device, dtype)
        # Complex dtype expected for each real input dtype.
        RESULT_TYPE = {
            torch.half: torch.chalf,
            torch.float: torch.cfloat,
            torch.double: torch.cdouble,
        }

        for op in [
            torch.fft.fftn,
            torch.fft.ifftn,
            torch.fft.fft2,
            torch.fft.ifft2,
        ]:
            inp = make_tensor((10, 10), device=device, dtype=dtype)
            # Call the op under test; calling fftn unconditionally here would
            # leave the other three functions unexercised.
            out = op(inp, dim=[])

            expect_dtype = RESULT_TYPE.get(inp.dtype, inp.dtype)
            expect = inp.to(expect_dtype)
            self.assertEqual(expect, out)
486

487

488
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_hfftn(self, device, dtype):
        """Check ``hfftn`` against the real signal its input was derived from.

        A random real tensor is transformed with ``ifftn``, the last
        transformed dimension is truncated to ``size // 2 + 1`` entries, and
        ``hfftn`` of the truncated tensor (with the original sizes passed as
        ``s``) must reproduce the starting tensor.
        """
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            actual_dims = list(range(input_ndim)) if dim is None else dim
            if dtype is torch.half:
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))
            expect = torch.randn(*shape, device=device, dtype=dtype)
            spectrum = torch.fft.ifftn(expect, dim=dim, norm="ortho")

            # Keep only the leading half (plus one) of the last transformed dim.
            lastdim = actual_dims[-1]
            lastdim_size = spectrum.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            # Index with a tuple: a list of slices is a non-tuple sequence
            # index, which is deprecated for multidimensional indexing.
            spectrum = spectrum[tuple(idx)]

            s = [shape[d] for d in actual_dims]
            actual = torch.fft.hfftn(spectrum, s=s, dim=dim, norm="ortho")

            self.assertEqual(expect, actual)
526

527
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @toleranceOverride({
        torch.half : tol(1e-2, 1e-2),
    })
    @dtypes(torch.half, torch.float, torch.double)
    def test_ihfftn(self, device, dtype):
        """Check ``ihfftn`` against a truncated ``ifftn`` of the same real input.

        The expected value is ``ifftn`` of a random real tensor with the last
        transformed dimension cut down to ``size // 2 + 1`` entries.
        """
        skip_helper_for_fft(device, dtype)

        # input_ndim, dim
        transform_desc = [
            *product(range(2, 5), (None, (0,), (0, -1))),
            (6, None),
            (5, (1, 3, 4)),
            (3, (1,)),
            (1, (0,)),
            (4, (0, 1))
        ]

        for input_ndim, dim in transform_desc:
            if dtype is torch.half:
                shape = tuple(itertools.islice(itertools.cycle((2, 4, 8)), input_ndim))
            else:
                shape = tuple(itertools.islice(itertools.cycle(range(4, 9)), input_ndim))

            signal = torch.randn(*shape, device=device, dtype=dtype)
            expect = torch.fft.ifftn(signal, dim=dim, norm="ortho")

            # Slice off the half-symmetric component
            lastdim = -1 if dim is None else dim[-1]
            lastdim_size = expect.size(lastdim) // 2 + 1
            idx = [slice(None)] * input_ndim
            idx[lastdim] = slice(0, lastdim_size)
            # Index with a tuple: a list of slices is a non-tuple sequence
            # index, which is deprecated for multidimensional indexing.
            expect = expect[tuple(idx)]

            actual = torch.fft.ihfftn(signal, dim=dim, norm="ortho")
            self.assertEqual(expect, actual)
564

565

566
    # 2d-fft tests
567

568
    # NOTE: 2d transforms are only thin wrappers over n-dim transforms,
569
    # so don't require exhaustive testing.
570

571

572
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.double, torch.complex128)
    def test_fft2_numpy(self, device, dtype):
        """Compare the 2-D transforms, eager and scripted, with NumPy/SciPy.

        Each function is checked once with ``dim`` left at its default and
        once with ``dim=(1, 0)``.  The ``hfft2``/``ihfft2`` references come
        from ``scipy.fft`` and are skipped when it is not importable; all
        other references come from ``np.fft``.
        """
        norm_modes = REFERENCE_NORM_MODES

        # input_ndim, s
        transform_desc = [
            *product(range(2, 5), (None, (4, 10))),
        ]

        fft_functions = ['fft2', 'ifft2', 'irfft2', 'hfft2']
        if dtype.is_floating_point:
            fft_functions += ['rfft2', 'ihfft2']

        for input_ndim, s in transform_desc:
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            signal = torch.randn(*shape, device=device, dtype=dtype)
            signal_np = signal.cpu().numpy()

            for fname, norm in product(fft_functions, norm_modes):
                torch_fn = getattr(torch.fft, fname)
                if "hfft" in fname:
                    if not has_scipy_fft:
                        continue  # Requires scipy to compare against
                    numpy_fn = getattr(scipy.fft, fname)
                else:
                    numpy_fn = getattr(np.fft, fname)

                # Typed wrapper so the same call can also be run via TorchScript.
                def fn(t: torch.Tensor, s: Optional[List[int]], dim: List[int] = (-2, -1), norm: Optional[str] = None):
                    return torch_fn(t, s, dim, norm)

                torch_fns = (torch_fn, torch.jit.script(fn))

                # Once with dim defaulted
                expected = numpy_fn(signal_np, s, norm=norm)
                for impl in torch_fns:
                    actual = impl(signal, s, norm=norm)
                    self.assertEqual(actual, expected)

                # Once with explicit dims
                dim = (1, 0)
                expected = numpy_fn(signal_np, s, dim, norm)
                for impl in torch_fns:
                    actual = impl(signal, s, dim, norm)
                    self.assertEqual(actual, expected)
617

618
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.complex64)
    def test_fft2_fftn_equivalence(self, device, dtype):
        """Each ``*2`` transform must match its ``*n`` counterpart.

        When ``dim`` is not given to the 2-D function, the N-D function is
        called with ``dim=(-2, -1)`` for the comparison.
        """
        norm_modes = (None, "forward", "backward", "ortho")

        # input_ndim, s, dim
        transform_desc = list(product(range(2, 5), (None, (4, 10)), (None, (1, 0))))
        transform_desc.append((3, None, (0, 2)))

        base_names = ['fft', 'ifft', 'irfft', 'hfft']
        # Real-only functions
        if dtype.is_floating_point:
            base_names += ['rfft', 'ihfft']

        for input_ndim, s, dim in transform_desc:
            shape = itertools.islice(itertools.cycle(range(4, 9)), input_ndim)
            x = torch.randn(*shape, device=device, dtype=dtype)

            for base_name, norm in product(base_names, norm_modes):
                two_dim_fn = getattr(torch.fft, base_name + '2')
                n_dim_fn = getattr(torch.fft, base_name + 'n')

                kwargs_2d = {'s': s, 'norm': norm}
                if dim is not None:
                    kwargs_2d['dim'] = dim

                # The N-D call always gets an explicit dim.
                kwargs_nd = dict(kwargs_2d)
                kwargs_nd.setdefault('dim', (-2, -1))

                expect = n_dim_fn(x, **kwargs_nd)
                actual = two_dim_fn(x, **kwargs_2d)

                self.assertEqual(actual, expect)
654

655
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    def test_fft2_invalid(self, device):
        """2-D transforms must raise the expected error for invalid arguments."""
        a = torch.rand(10, 10, 10, device=device)

        # (exception type, message pattern, keyword arguments for the function)
        invalid_calls = (
            (RuntimeError, "dims must be unique", dict(dim=(0, 0))),
            (RuntimeError, "dims must be unique", dict(dim=(2, -1))),
            (RuntimeError, "dim and shape .* same length", dict(s=(1,))),
            (IndexError, "Dimension out of range", dict(dim=(2, 3))),
        )
        for func in (torch.fft.fft2, torch.fft.ifft2,
                     torch.fft.rfft2, torch.fft.irfft2):
            for error_type, pattern, kwargs in invalid_calls:
                with self.assertRaisesRegex(error_type, pattern):
                    func(a, **kwargs)

        # rfft2 additionally rejects complex input.
        complex_input = torch.complex(a, a)
        with self.assertRaisesRegex(RuntimeError, "rfftn expects a real-valued input"):
            torch.fft.rfft2(complex_input)
678

679
    # Helper functions
680

681
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftfreq_numpy(self, device, dtype):
        """``fftfreq`` and ``rfftfreq`` must match NumPy for n in 1..19.

        Each length is tested with the sample spacing ``d`` omitted and with
        ``d=10.0``; result dtypes are not compared.
        """
        # (n, d) pairs; d=None means the argument is left out of the call
        size_and_spacing = list(product(range(1, 20), (None, 10.0)))

        for fname in ('fftfreq', 'rfftfreq'):
            torch_fn = getattr(torch.fft, fname)
            numpy_fn = getattr(np.fft, fname)

            for n, d in size_and_spacing:
                if d is None:
                    args = (n,)
                else:
                    args = (n, d)
                expected = numpy_fn(*args)
                actual = torch_fn(*args, device=device, dtype=dtype)
                self.assertEqual(actual, expected, exact_dtype=False)
706

707
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double)
    def test_fftfreq_out(self, device, dtype):
        """``fftfreq``/``rfftfreq`` with ``out=`` must warn about resizing and fill ``out``.

        A 0-dim tensor is passed as ``out``; afterwards it must equal the
        result of the same call made without ``out``.
        """
        for freq_fn in (torch.fft.fftfreq, torch.fft.rfftfreq):
            expect = freq_fn(n=100, d=.5, device=device, dtype=dtype)
            out_tensor = torch.empty((), device=device, dtype=dtype)
            with self.assertWarnsRegex(UserWarning, "out tensor will be resized"):
                freq_fn(n=100, d=.5, out=out_tensor)
            self.assertEqual(out_tensor, expect)
717

718

719
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double, torch.complex64, torch.complex128)
    def test_fftshift_numpy(self, device, dtype):
        """``fftshift`` and ``ifftshift`` must match NumPy for several shapes and dims."""
        # shape, dim
        test_args = list(product(((11,), (12,)), (None, 0, -1)))
        test_args += product(((4, 5), (6, 6)), (None, 0, (-1,)))
        test_args += product(((1, 1, 4, 6, 7, 2),), (None, (3, 4)))

        for shape, dim in test_args:
            tensor = torch.rand(*shape, device=device, dtype=dtype)
            array = tensor.cpu().numpy()

            for fname in ('fftshift', 'ifftshift'):
                expected = getattr(np.fft, fname)(array, axes=dim)
                actual = getattr(torch.fft, fname)(tensor, dim=dim)
                self.assertEqual(actual, expected)
744

745
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
    @dtypes(torch.float, torch.double)
    def test_fftshift_frequencies(self, device, dtype):
        """``fftshift`` must sort ``fftfreq`` output and ``ifftshift`` must undo it.

        Checked for n = 10..14 with ``d = 1 / n``; the shifted frequencies
        must equal ``arange(-(n // 2), n - n // 2)``.
        """
        for n in range(10, 15):
            half = n // 2
            ascending_freqs = torch.arange(-half, n - half, device=device, dtype=dtype)
            freqs = torch.fft.fftfreq(n, d=1 / n, device=device, dtype=dtype)

            # Test fftshift sorts the fftfreq output
            shifted = torch.fft.fftshift(freqs)
            self.assertEqual(shifted, shifted.sort().values)
            self.assertEqual(ascending_freqs, shifted)

            # And ifftshift is the inverse
            restored = torch.fft.ifftshift(shifted)
            self.assertEqual(freqs, restored)
762

763
    # Legacy fft tests
764
    def _test_fft_ifft_rfft_irfft(self, device, dtype):
        """Round-trip n-dimensional FFTs on contiguous and strided inputs.

        For complex inputs, checks ``ifftn(fftn(x)) == x`` and
        ``fftn(ifftn(x)) == x``. For real inputs, checks
        ``irfftn(rfftn(x)) == x`` and ``ifftn(fftn(x)) == x``. Every check
        runs with the default and the ``'ortho'`` normalization.

        Args:
            device: device on which the random inputs are created.
            dtype: real floating dtype of the real inputs; the complex inputs
                use ``corresponding_complex_dtype(dtype)``.
        """
        complex_dtype = corresponding_complex_dtype(dtype)

        def _test_complex(sizes, signal_ndim, prepro_fn=lambda x: x):
            # `prepro_fn` turns the fresh random tensor into the (possibly
            # non-contiguous) view that is actually transformed.
            x = prepro_fn(torch.randn(*sizes, dtype=complex_dtype, device=device))
            # Transform over the trailing `signal_ndim` dimensions.
            dim = tuple(range(-signal_ndim, 0))
            for norm in ('ortho', None):
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='fft and ifft')
                res = torch.fft.ifftn(x, dim=dim, norm=norm)
                rec = torch.fft.fftn(res, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='ifft and fft')

        def _test_real(sizes, signal_ndim, prepro_fn=lambda x: x):
            x = prepro_fn(torch.randn(*sizes, dtype=dtype, device=device))
            # irfftn needs the original signal sizes to undo the one-sided
            # truncation performed by rfftn.
            signal_sizes = x.size()[-signal_ndim:]
            dim = tuple(range(-signal_ndim, 0))
            for norm in (None, 'ortho'):
                res = torch.fft.rfftn(x, dim=dim, norm=norm)
                rec = torch.fft.irfftn(res, s=signal_sizes, dim=dim, norm=norm)
                self.assertEqual(x, rec, atol=1e-8, rtol=0, msg='rfft and irfft')
                res = torch.fft.fftn(x, dim=dim, norm=norm)
                rec = torch.fft.ifftn(res, dim=dim, norm=norm)
                # The complex round trip returns a complex tensor, so compare
                # against the input promoted to complex with zero imaginary part.
                x_complex = torch.complex(x, torch.zeros_like(x))
                self.assertEqual(x_complex, rec, atol=1e-8, rtol=0, msg='fft and ifft (from real)')

        # contiguous case
        _test_real((100,), 1)
        _test_real((10, 1, 10, 100), 1)
        _test_real((100, 100), 2)
        _test_real((2, 2, 5, 80, 60), 2)
        _test_real((50, 40, 70), 3)
        _test_real((30, 1, 50, 25, 20), 3)

        _test_complex((100,), 1)
        _test_complex((100, 100), 1)
        _test_complex((100, 100), 2)
        _test_complex((1, 20, 80, 60), 2)
        _test_complex((50, 40, 70), 3)
        _test_complex((6, 5, 50, 25, 20), 3)

        # non-contiguous case
        _test_real((165,), 1, lambda x: x.narrow(0, 25, 100))  # input is not aligned to complex type
        _test_real((100, 100, 3), 1, lambda x: x[:, :, 0])
        _test_real((100, 100), 2, lambda x: x.t())
        _test_real((20, 100, 10, 10), 2, lambda x: x.view(20, 100, 100)[:, :60])
        _test_real((65, 80, 115), 3, lambda x: x[10:60, 13:53, 10:80])
        _test_real((30, 20, 50, 25), 3, lambda x: x.transpose(1, 2).transpose(2, 3))

        _test_complex((100,), 1, lambda x: x.expand(100, 100))
        _test_complex((20, 90, 110), 2, lambda x: x[:, 5:85].narrow(2, 5, 100))
        _test_complex((40, 60, 3, 80), 3, lambda x: x.transpose(2, 0).select(0, 2)[5:55, :, 10:])
        _test_complex((30, 55, 50, 22), 3, lambda x: x[:, 3:53, 15:40, 1:21])
819

820
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_fft_ifft_rfft_irfft(self, device, dtype):
        """Run the shared FFT round-trip checks on the device under test."""
        self._test_fft_ifft_rfft_irfft(device=device, dtype=dtype)
825

826
    @deviceCountAtLeast(1)
    @onlyCUDA
    @dtypes(torch.double)
    def test_cufft_plan_cache(self, devices, dtype):
        """Exercise the cuFFT plan cache: resizing, clearing, validation, per-device state."""
        plan_caches = torch.backends.cuda.cufft_plan_cache
        primary = devices[0]

        @contextmanager
        def plan_cache_max_size(device, n):
            # Temporarily set `max_size` of the chosen cache to `n`; passing
            # `device=None` goes through the un-indexed cache object.
            plan_cache = plan_caches if device is None else plan_caches[device]
            original = plan_cache.max_size
            plan_cache.max_size = n
            try:
                yield
            finally:
                plan_cache.max_size = original

        def run_fft_suite():
            self._test_fft_ifft_rfft_irfft(primary, dtype)

        def check_max_size(cache, expected):
            self.assertEqual(cache.max_size, expected)

        with plan_cache_max_size(primary, max(1, plan_caches.size - 10)):
            run_fft_suite()

        with plan_cache_max_size(primary, 0):
            run_fft_suite()

        plan_caches.clear()

        # check that everything still works after clearing the cache
        with plan_cache_max_size(primary, 10):
            run_fft_suite()

        with self.assertRaisesRegex(RuntimeError, r"must be non-negative"):
            plan_caches.max_size = -1

        with self.assertRaisesRegex(RuntimeError, r"read-only property"):
            plan_caches.size = -1

        with self.assertRaisesRegex(RuntimeError, r"but got device with index"):
            plan_caches[torch.cuda.device_count() + 10]

        # Everything below needs at least two GPUs.
        if len(devices) <= 1:
            return
        secondary = devices[1]

        # Test that different GPU has different cache
        x0 = torch.randn(2, 3, 3, device=primary)
        x1 = x0.to(secondary)
        self.assertEqual(torch.fft.rfftn(x0, dim=(-2, -1)), torch.fft.rfftn(x1, dim=(-2, -1)))
        # If a plan is used across different devices, the following line (or
        # the assert above) would trigger illegal memory access. Other ways
        # to trigger the error include
        #   (1) setting CUDA_LAUNCH_BLOCKING=1 (pytorch/pytorch#19224) and
        #   (2) printing a device 1 tensor.
        x0.copy_(x1)

        # Test that un-indexed `torch.backends.cuda.cufft_plan_cache` uses current device
        with plan_cache_max_size(primary, 10):
            with plan_cache_max_size(secondary, 11):
                check_max_size(plan_caches[0], 10)
                check_max_size(plan_caches[1], 11)

                check_max_size(plan_caches, 10)  # default is cuda:0
                with torch.cuda.device(secondary):
                    check_max_size(plan_caches, 11)  # default is cuda:1
                    with torch.cuda.device(primary):
                        check_max_size(plan_caches, 10)  # default is cuda:0

            check_max_size(plan_caches[0], 10)
            # The resize via `None` happens while cuda:1 is the current device.
            with torch.cuda.device(secondary), plan_cache_max_size(None, 11):
                check_max_size(plan_caches[0], 10)
                check_max_size(plan_caches[1], 11)

                check_max_size(plan_caches, 11)  # default is cuda:1
                with torch.cuda.device(primary):
                    check_max_size(plan_caches, 10)  # default is cuda:0
                check_max_size(plan_caches, 11)  # default is cuda:1
899

900
    @onlyCUDA
    @dtypes(torch.cfloat, torch.cdouble)
    def test_cufft_context(self, device, dtype):
        """Gradient of ifft(fft(x)) must equal fft(ifft(grad_output)) exactly."""
        # Regression test for https://github.com/pytorch/pytorch/issues/109448
        numel = 32
        x = torch.randn(numel, dtype=dtype, device=device, requires_grad=True)
        dout = torch.zeros(numel, dtype=dtype, device=device)

        # Forward iFFT(FFT(x)), then push `dout` back through the graph.
        spectrum = torch.fft.fft(x)
        out = torch.fft.ifft(spectrum)
        out.backward(dout, retain_graph=True)

        # The same gradient computed directly, without autograd.
        dx = torch.fft.fft(torch.fft.ifft(dout))

        grad_vs_reference = (x.grad - dx).abs().max()
        self.assertTrue(grad_vs_reference == 0)
        grad_vs_input = (x.grad - x).abs().max()
        self.assertFalse(grad_vs_input == 0)
915

916
    # passes on ROCm w/ python 2.7, fails w/ python 3.6
917
    @skipIfTorchDynamo("cannot set WRITEABLE flag to True of this array")
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_stft(self, device, dtype):
        """Compare ``Tensor.stft`` with librosa and check that bad arguments raise.

        Skipped when librosa is not available.
        """
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_stft(x, n_fft, hop_length, win_length, window, center):
            # Reference STFT computed row by row with librosa; the result is a
            # real tensor whose last dimension holds (real, imag).
            if window is None:
                window = np.ones(n_fft if win_length is None else win_length)
            else:
                window = window.cpu().numpy()
            input_1d = x.dim() == 1
            if input_1d:
                x = x.view(1, -1)

            # NOTE: librosa 0.9 changed default pad_mode to 'constant' (zero padding)
            # however, we use the pre-0.9 default ('reflect')
            pad_mode = 'reflect'

            result = []
            for xi in x:
                ri = librosa.stft(xi.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                                  win_length=win_length, window=window, center=center,
                                  pad_mode=pad_mode)
                result.append(torch.from_numpy(np.stack([ri.real, ri.imag], -1)))
            result = torch.stack(result, 0)
            if input_1d:
                result = result[0]
            return result

        def _test(sizes, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  center=True, expected_error=None):
            x = torch.randn(*sizes, dtype=dtype, device=device)
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)
            else:
                window = None

            if expected_error is not None:
                # Pass return_complex explicitly: stft rejects real input that
                # omits it, which would hide the argument error under test.
                with self.assertRaises(expected_error):
                    x.stft(n_fft, hop_length, win_length, window,
                           center=center, return_complex=True)
                return

            result = x.stft(n_fft, hop_length, win_length, window,
                            center=center, return_complex=False)
            # NB: librosa defaults to np.complex64 output, no matter what
            # the input dtype
            ref_result = librosa_stft(x, n_fft, hop_length, win_length, window, center)
            self.assertEqual(result, ref_result, atol=7e-6, rtol=0, msg='stft comparison against librosa', exact_dtype=False)
            # With return_complex=True, the result is the same but viewed as complex instead of real
            result_complex = x.stft(n_fft, hop_length, win_length, window, center=center, return_complex=True)
            self.assertEqual(result_complex, torch.view_as_complex(result))

        for center in [True, False]:
            _test((10,), 7, center=center)
            _test((10, 4000), 1024, center=center)

            _test((10,), 7, 2, center=center)
            _test((10, 4000), 1024, 512, center=center)

            _test((10,), 7, 2, win_sizes=(7,), center=center)
            _test((10, 4000), 1024, 512, win_sizes=(1024,), center=center)

            # spectral oversample
            _test((10,), 7, 2, win_length=5, center=center)
            _test((10, 4000), 1024, 512, win_length=100, center=center)

        # Invalid arguments: 3-D input, n_fft longer than the signal,
        # negative n_fft, win_length > n_fft, window length mismatch, 2-D window.
        _test((10, 4, 2), 1, 1, expected_error=RuntimeError)
        _test((10,), 11, 1, center=False, expected_error=RuntimeError)
        _test((10,), -1, 1, expected_error=RuntimeError)
        _test((10,), 3, win_length=5, expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(11,), expected_error=RuntimeError)
        _test((10,), 5, 4, win_sizes=(1, 1), expected_error=RuntimeError)
990

991
    @skipIfTorchDynamo("double")
    @skipCPUIfNoFFT
    @onlyNativeDeviceTypes
    @dtypes(torch.double)
    def test_istft_against_librosa(self, device, dtype):
        """``Tensor.istft`` must match ``librosa.istft`` on STFTs of random signals."""
        if not TEST_LIBROSA:
            raise unittest.SkipTest('librosa not found')

        def librosa_istft(x, n_fft, hop_length, win_length, window, length, center):
            # Without a window, fall back to a rectangular one of the
            # effective window length.
            if window is not None:
                window_np = window.cpu().numpy()
            else:
                window_np = np.ones(win_length if win_length is not None else n_fft)

            return librosa.istft(x.cpu().numpy(), n_fft=n_fft, hop_length=hop_length,
                                 win_length=win_length, length=length, window=window_np, center=center)

        def _test(size, n_fft, hop_length=None, win_length=None, win_sizes=None,
                  length=None, center=True):
            signal = torch.randn(size, dtype=dtype, device=device)
            window = None
            if win_sizes is not None:
                window = torch.randn(*win_sizes, dtype=dtype, device=device)

            spectrogram = signal.stft(n_fft, hop_length, win_length, window, center=center,
                                      onesided=True, return_complex=True)

            expected = librosa_istft(spectrogram, n_fft, hop_length, win_length,
                                     window, length, center)
            actual = spectrogram.istft(n_fft, hop_length, win_length, window,
                                       length=length, center=center)
            self.assertEqual(actual, expected)

        # (positional args, extra keyword args) for each scenario, run once
        # per `center` setting in the listed order.
        scenarios = [
            ((10, 7), {}),
            ((4000, 1024), {}),
            ((4000, 1024), {'length': 4000}),

            ((10, 7, 2), {}),
            ((4000, 1024, 512), {}),
            ((4000, 1024, 512), {'length': 4000}),

            ((10, 7, 2), {'win_sizes': (7,)}),
            ((4000, 1024, 512), {'win_sizes': (1024,)}),
            ((4000, 1024, 512), {'win_sizes': (1024,), 'length': 4000}),
        ]
        for center in [True, False]:
            for positional, extra in scenarios:
                _test(*positional, center=center, **extra)
1037

1038
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_complex_stft_roundtrip(self, device, dtype):
        """``istft(stft(x))`` must reproduce ``x`` for real and complex signals.

        The round trip is checked through the functional interface
        (``torch.stft`` / ``torch.istft``) and through the tensor methods
        (``Tensor.stft`` / ``Tensor.istft``).
        """
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype),
             torch.randn(12, 60, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # center
            (True,),
            # pad_mode
            ("constant", "reflect", "circular"),
            # normalized
            (True, False),
            # onesided
            (True, False) if not dtype.is_complex else (False,),
        ))

        for args in test_args:
            x, n_fft, hop_length, center, pad_mode, normalized, onesided = args
            common_kwargs = {
                'n_fft': n_fft, 'hop_length': hop_length, 'center': center,
                'normalized': normalized, 'onesided': onesided,
            }
            # The inverse returns complex output only for complex signals and
            # is asked for exactly the original signal length.
            istft_kwargs = dict(common_kwargs, return_complex=dtype.is_complex,
                                length=x.size(-1))

            # Functional interface
            x_stft = torch.stft(x, pad_mode=pad_mode, return_complex=True, **common_kwargs)
            x_roundtrip = torch.istft(x_stft, **istft_kwargs)
            self.assertEqual(x_roundtrip, x)

            # Tensor method interface (both directions go through the methods)
            x_stft = x.stft(pad_mode=pad_mode, return_complex=True, **common_kwargs)
            x_roundtrip = x_stft.istft(**istft_kwargs)
            self.assertEqual(x_roundtrip, x)
1079

1080
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double, torch.cdouble)
    def test_stft_roundtrip_complex_window(self, device, dtype):
        """stft/istft with a complex window must round-trip real and complex signals."""
        test_args = list(product(
            # input
            (torch.randn(600, device=device, dtype=dtype),
             torch.randn(807, device=device, dtype=dtype),
             torch.randn(12, 60, device=device, dtype=dtype)),
            # n_fft
            (50, 27),
            # hop_length
            (None, 10),
            # pad_mode
            ("constant", "reflect", "replicate", "circular"),
            # normalized
            (True, False),
        ))
        for args in test_args:
            x, n_fft, hop_length, pad_mode, normalized = args
            window = torch.rand(n_fft, device=device, dtype=torch.cdouble)
            x_stft = torch.stft(
                x, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, pad_mode=pad_mode, normalized=normalized)
            self.assertEqual(x_stft.dtype, torch.cdouble)
            self.assertEqual(x_stft.size(-2), n_fft)  # Not onesided

            x_roundtrip = torch.istft(
                x_stft, n_fft=n_fft, hop_length=hop_length, window=window,
                center=True, normalized=normalized, length=x.size(-1),
                return_complex=True)
            # return_complex=True: the reconstruction itself must be complex.
            self.assertEqual(x_roundtrip.dtype, torch.cdouble)

            if not dtype.is_complex:
                # A real signal comes back with a (numerically) zero imaginary part.
                self.assertEqual(x_roundtrip.imag, torch.zeros_like(x_roundtrip.imag),
                                 atol=1e-6, rtol=0)
                self.assertEqual(x_roundtrip.real, x)
            else:
                self.assertEqual(x_roundtrip, x)
1119

1120

1121
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_definition(self, device, dtype):
        """Uncentered ``torch.stft`` on complex input must match ``_stft_reference``."""
        signals = (
            torch.randn(600, device=device, dtype=dtype),
            torch.randn(807, device=device, dtype=dtype),
        )
        n_fft_choices = (50, 27)
        hop_length_choices = (10, 15)

        cases = list(product(signals, n_fft_choices, hop_length_choices))
        for signal, n_fft, hop_length in cases:
            # The random window has exactly n_fft samples.
            window = torch.randn(n_fft, device=device, dtype=dtype)
            expected = _stft_reference(signal, hop_length, window)
            actual = torch.stft(signal, n_fft, hop_length, window=window, center=False)
            self.assertEqual(actual, expected)
1139

1140
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_stft_real_equiv(self, device, dtype):
        """``torch.stft`` of a complex signal must equal the ``_complex_stft`` helper's result."""
        signals = (
            torch.rand(600, device=device, dtype=dtype),
            torch.rand(807, device=device, dtype=dtype),
            torch.rand(14, 50, device=device, dtype=dtype),
            torch.rand(6, 51, device=device, dtype=dtype),
        )
        grid = product(
            signals,
            (50, 27),                            # n_fft
            (None, 10),                          # hop_length
            (None, 20),                          # win_length
            (False, True),                       # center
            ("constant", "reflect", "circular"),  # pad_mode
            (True, False),                       # normalized
        )

        for x, n_fft, hop_length, win_length, center, pad_mode, normalized in list(grid):
            # Identical keyword arguments go to the reference and to torch.stft.
            stft_kwargs = {
                'hop_length': hop_length,
                'win_length': win_length,
                'pad_mode': pad_mode,
                'center': center,
                'normalized': normalized,
            }
            expected = _complex_stft(x, n_fft, **stft_kwargs)
            actual = torch.stft(x, n_fft, **stft_kwargs)
            self.assertEqual(expected, actual)
1173

1174
    @skipCPUIfNoFFT
    @dtypes(torch.cdouble)
    def test_complex_istft_real_equiv(self, device, dtype):
        """``torch.istft`` with complex output must equal the ``_complex_istft`` helper's result."""
        spectrograms = (
            torch.rand(40, 20, device=device, dtype=dtype),
            torch.rand(25, 1, device=device, dtype=dtype),
            torch.rand(4, 20, 10, device=device, dtype=dtype),
        )
        grid = product(
            spectrograms,
            (None, 10),     # hop_length
            (False, True),  # center
            (True, False),  # normalized
        )

        for spec, hop_length, center, normalized in list(grid):
            # n_fft is taken from the second-to-last dimension of the input.
            n_fft = spec.size(-2)
            shared_kwargs = {
                'hop_length': hop_length,
                'center': center,
                'normalized': normalized,
            }
            expected = _complex_istft(spec, n_fft, **shared_kwargs)
            actual = torch.istft(spec, n_fft, return_complex=True, **shared_kwargs)
            self.assertEqual(expected, actual)
1199

1200
    @skipCPUIfNoFFT
    def test_complex_stft_onesided(self, device):
        """One-sided stft must be rejected when the input or the window is complex."""
        # stft of complex input cannot be onesided
        candidate_dtypes = (torch.double, torch.cdouble)
        for x_dtype, window_dtype in product(candidate_dtypes, repeat=2):
            x = torch.rand(100, device=device, dtype=x_dtype)
            window = torch.rand(10, device=device, dtype=window_dtype)

            all_real = not (x_dtype.is_complex or window_dtype.is_complex)
            if all_real:
                # Only the all-real combination may produce one-sided output.
                y = x.stft(10, window=window, pad_mode='constant', onesided=True,
                           return_complex=True)
                self.assertEqual(y.dtype, torch.cdouble)
                self.assertEqual(y.size(), (6, 51))
                continue

            with self.assertRaisesRegex(RuntimeError, 'complex'):
                x.stft(10, window=window, pad_mode='constant', onesided=True)

        # Complex input without any window is rejected as well.
        x = torch.rand(100, device=device, dtype=torch.cdouble)
        with self.assertRaisesRegex(RuntimeError, 'complex'):
            x.stft(10, pad_mode='constant', onesided=True)
1219

1220
    # stft raises for real input when return_complex is not given explicitly
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_stft_requires_complex(self, device):
        """stft on real input without ``return_complex`` must raise."""
        # Create the signal on the device under test rather than on the default device.
        x = torch.rand(100, device=device)
        with self.assertRaisesRegex(RuntimeError, 'stft requires the return_complex parameter'):
            x.stft(10, pad_mode='constant')
1227

1228
    # stft and istft are currently warning if a window is not provided
1229
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_stft_requires_window(self, device):
        """stft without a window must emit the "window was not provided" warning."""
        # Create the signal on the device under test rather than on the default device.
        x = torch.rand(100, device=device)
        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
            x.stft(10, pad_mode='constant', return_complex=True)
1235

1236
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_istft_requires_window(self, device):
        """istft without a window must emit the "window was not provided" warning."""
        # 51 = n_fft // 2 + 1 one-sided frequency bins, 5 = number of frames
        stft = torch.rand((51, 5), dtype=torch.cdouble, device=device)
        with self.assertWarnsOnceRegex(UserWarning, "A window was not provided"):
            torch.istft(stft, n_fft=100, length=100)
1243

1244
    @skipCPUIfNoFFT
    def test_fft_input_modification(self, device):
        """FFT functions must leave their input tensors untouched."""
        # FFT functions should not modify their input (gh-34551)
        last_two = (-2, -1)

        signal = torch.ones((2, 2, 2), device=device)
        signal_copy = signal.clone()

        # complex-to-complex, forward
        spectrum = torch.fft.fftn(signal, dim=last_two)
        self.assertEqual(signal, signal_copy)

        # complex-to-complex, inverse
        spectrum_copy = spectrum.clone()
        torch.fft.ifftn(spectrum, dim=last_two)
        self.assertEqual(spectrum, spectrum_copy)

        # real-to-complex
        half_spectrum = torch.fft.rfftn(signal, dim=last_two)
        self.assertEqual(signal, signal_copy)

        # complex-to-real: transform the clone, then compare it with the
        # untouched original.
        half_spectrum_copy = half_spectrum.clone()
        torch.fft.irfftn(half_spectrum_copy, s=(2, 2), dim=last_two)
        self.assertEqual(half_spectrum, half_spectrum_copy)
1263

1264
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_fft_plan_repeatable(self, device):
        """Transforming the same data twice must give identical results."""
        # Regression test for gh-58724 and gh-63152
        transforms = (
            (torch.complex64, torch.fft.fftn),
            (torch.float64, torch.fft.rfft),
        )
        for n in [2048, 3199, 5999]:
            for input_dtype, transform in transforms:
                a = torch.randn(n, device=device, dtype=input_dtype)
                first = transform(a)
                second = transform(a.clone())
                self.assertEqual(first, second)
1278

1279
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_simple_cases(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test(signal, n_fft, length):
            spectrogram = torch.stft(signal, n_fft=n_fft, return_complex=True)
            recovered = torch.istft(spectrogram, n_fft=n_fft, length=length)
            self.assertEqual(signal, recovered, exact_dtype=True)

        # Constant signals: all ones, then all zeros.
        for make_signal in (torch.ones, torch.zeros):
            _test(make_signal(4, dtype=dtype, device=device), 4, 4)
1291

1292
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_various_params(self, device, dtype):
        """stft -> istft should recover the original signal"""
        def _test_istft_is_inverse_of_stft(stft_kwargs):
            # Draws a fresh random signal for every trial, runs stft followed
            # by istft and checks that the signal is reconstructed.
            data_sizes = [(2, 20), (3, 15), (4, 10)]
            num_trials = 100
            # istft takes the same arguments as stft except for `pad_mode`.
            istft_kwargs = {key: value for key, value in stft_kwargs.items()
                            if key != 'pad_mode'}
            for sizes, _ in product(data_sizes, range(num_trials)):
                original = torch.randn(*sizes, dtype=dtype, device=device)
                spectrogram = torch.stft(original, return_complex=True, **stft_kwargs)
                inversed = torch.istft(spectrogram, length=original.size(1), **istft_kwargs)
                self.assertEqual(
                    inversed, original, msg='istft comparison against original',
                    atol=7e-6, rtol=0, exact_dtype=True)

        def make_pattern(n_fft, hop_length, win_length, window_fn, pad_mode,
                         normalized, onesided):
            # All patterns are centered and build their window on the
            # device/dtype under test.
            return {
                'n_fft': n_fft,
                'hop_length': hop_length,
                'win_length': win_length,
                'window': window_fn(win_length, dtype=dtype, device=device),
                'center': True,
                'pad_mode': pad_mode,
                'normalized': normalized,
                'onesided': onesided,
            }

        patterns = [
            # hann_window, centered, normalized, onesided
            make_pattern(12, 4, 12, torch.hann_window, 'reflect', True, True),
            # hann_window, centered, not normalized, not onesided
            make_pattern(12, 2, 8, torch.hann_window, 'reflect', False, False),
            # hamming_window, centered, normalized, not onesided
            make_pattern(15, 3, 11, torch.hamming_window, 'constant', True, False),
            # hamming_window, centered, not normalized, onesided
            # window same size as n_fft
            make_pattern(5, 2, 5, torch.hamming_window, 'constant', False, True),
        ]
        for pattern in patterns:
            _test_istft_is_inverse_of_stft(pattern)
1362

1363
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_round_trip_with_padding(self, device, dtype):
        """long hop_length or not centered may cause length mismatch in the inversed signal"""
        def _test_istft_is_inverse_of_stft_with_padding(stft_kwargs):
            # Draws a fresh random signal for every trial, runs stft followed
            # by istft and checks the reconstructed part and the zero padding.
            num_trials = 100
            # Work on a copy so the caller's pattern dict is left untouched.
            stft_kwargs = dict(stft_kwargs)
            sizes = stft_kwargs.pop('size')
            # istft takes the same arguments as stft except for `pad_mode`.
            istft_kwargs = stft_kwargs.copy()
            del istft_kwargs['pad_mode']

            n_fft = stft_kwargs["n_fft"]
            hop_length = stft_kwargs["hop_length"]
            # Length contributed by the first frame.
            first_frame_len = n_fft // 2 if stft_kwargs["center"] is True else n_fft

            for _ in range(num_trials):
                original = torch.randn(*sizes, dtype=dtype, device=device)
                stft = torch.stft(original, return_complex=True, **stft_kwargs)
                with self.assertWarnsOnceRegex(UserWarning, "The length of signal is shorter than the length parameter."):
                    inversed = torch.istft(stft, length=original.size(-1), **istft_kwargs)
                n_frames = stft.size(-1)
                len_expected = first_frame_len + hop_length * (n_frames - 1)
                # trim the original for case when constructed signal is shorter than original
                padding = inversed[..., len_expected:]
                inversed = inversed[..., :len_expected]
                original = original[..., :len_expected]
                # test the padding points of the inversed signal are all zeros
                zeros = torch.zeros_like(padding)
                self.assertEqual(
                    padding, zeros, msg='istft padding values against zeros',
                    atol=7e-6, rtol=0, exact_dtype=True)
                self.assertEqual(
                    inversed, original, msg='istft comparison against original',
                    atol=7e-6, rtol=0, exact_dtype=True)

        patterns = [
            # hamming_window, not centered, not normalized, not onesided
            # window same size as n_fft
            {
                'size': [2, 20],
                'n_fft': 3,
                'hop_length': 2,
                'win_length': 3,
                'window': torch.hamming_window(3, dtype=dtype, device=device),
                'center': False,
                'pad_mode': 'reflect',
                'normalized': False,
                'onesided': False,
            },
            # hamming_window, centered, not normalized, onesided, long hop_length
            # window same size as n_fft
            {
                'size': [2, 500],
                'n_fft': 256,
                'hop_length': 254,
                'win_length': 256,
                'window': torch.hamming_window(256, dtype=dtype, device=device),
                'center': True,
                'pad_mode': 'constant',
                'normalized': False,
                'onesided': True,
            },
        ]
        for pattern in patterns:
            _test_istft_is_inverse_of_stft_with_padding(pattern)
1429

1430
    @onlyNativeDeviceTypes
    def test_istft_throws(self, device):
        """istft should throw exception for invalid parameters

        Every input is a complex tensor created on ``device`` so that each
        RuntimeError comes from the parameter being tested rather than from
        an unrelated dtype or device-mismatch rejection.
        """
        # Complex spectrogram with 3 frequency bins and 5 frames
        # (onesided layout for n_fft=4, since 4 // 2 + 1 == 3).
        stft = torch.zeros((3, 5), device=device, dtype=torch.complex64)
        # the window is size 1 but it hops 20 so there is a gap which throws an error
        self.assertRaises(
            RuntimeError, torch.istft, stft, n_fft=4,
            hop_length=20, win_length=1, window=torch.ones(1, device=device))
        # A window of zeros does not meet NOLA
        invalid_window = torch.zeros(4, device=device)
        self.assertRaises(
            RuntimeError, torch.istft, stft, n_fft=4, win_length=4, window=invalid_window)
        # Input cannot be empty (zero frames, then zero frequency bins)
        for empty_shape in ((3, 0), (0, 3)):
            empty = torch.zeros(empty_shape, device=device, dtype=torch.complex64)
            self.assertRaises(RuntimeError, torch.istft, empty, 2)
1445

1446
    @skipIfTorchDynamo("Failed running call_function")
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_of_sine(self, device, dtype):
        """Rebuild sine waves with istft from hand-constructed spectrograms."""
        complex_dtype = corresponding_complex_dtype(dtype)

        def _check_sine(amplitude, L, n):
            # Reference signal amplitude*sin(2*pi/L*n*x); the hop length and
            # window size used below both equal L.
            samples = torch.arange(2 * L + 1, device=device, dtype=dtype)
            expected = amplitude * torch.sin(2 * math.pi / L * samples * n)
            # The spectrogram is written out by hand instead of calling
            #   torch.stft(expected, L, hop_length=L, win_length=L,
            #              window=torch.ones(L), center=False, normalized=False)
            spectrum = torch.zeros((L // 2 + 1, 2), device=device, dtype=complex_dtype)
            n_bins = spectrum.size(0)
            peak = (amplitude * L) / 2.0
            if n < n_bins:
                spectrum[n].imag = torch.tensor(-peak, dtype=dtype)

            mirrored_bin = L - n
            if 0 <= mirrored_bin < n_bins:
                # symmetric about L // 2
                spectrum[mirrored_bin].imag = torch.tensor(peak, dtype=dtype)

            rect_window = torch.ones(L, device=device, dtype=dtype)
            reconstructed = torch.istft(
                spectrum, L, hop_length=L, win_length=L,
                window=rect_window, center=False, normalized=False)
            # Compare only the samples that istft actually produced.
            expected = expected[..., :reconstructed.size(-1)]
            # There is a larger error due to the scaling of amplitude
            self.assertEqual(reconstructed, expected, atol=1e-3, rtol=0)

        # (amplitude, L, n) triples, run in this order
        cases = (
            (123, 5, 1),
            (150, 5, 2),
            (111, 5, 3),
            (160, 7, 4),
            (145, 8, 5),
            (80, 9, 6),
            (99, 10, 7),
        )
        for amplitude, L, n in cases:
            _check_sine(amplitude=amplitude, L=L, n=n)
1482

1483
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    @dtypes(torch.double)
    def test_istft_linearity(self, device, dtype):
        """Check istft(a*x + b*y) == a*istft(x) + b*istft(y) on random inputs."""
        num_trials = 100
        complex_dtype = corresponding_complex_dtype(dtype)

        def _check_linearity(data_size, kwargs):
            for _ in range(num_trials):
                first = torch.randn(data_size, device=device, dtype=complex_dtype)
                second = torch.randn(data_size, device=device, dtype=complex_dtype)
                a, b = torch.rand(2, dtype=dtype, device=device)
                # Also compare method vs. functional call signature
                combined_outputs = a * first.istft(**kwargs) + b * second.istft(**kwargs)
                output_of_combined = torch.istft(a * first + b * second, **kwargs)
                self.assertEqual(combined_outputs, output_of_combined, atol=1e-5, rtol=0)

        def _pattern(data_size, make_window, center, normalized, onesided):
            # Every pattern uses n_fft=12 with a window of the same length.
            return data_size, {
                'n_fft': 12,
                'window': make_window(12, device=device, dtype=dtype),
                'center': center,
                'normalized': normalized,
                'onesided': onesided,
            }

        patterns = [
            # hann_window, centered, normalized, onesided
            _pattern((2, 7, 7), torch.hann_window, center=True, normalized=True, onesided=True),
            # hann_window, centered, not normalized, not onesided
            _pattern((2, 12, 7), torch.hann_window, center=True, normalized=False, onesided=False),
            # hamming_window, centered, normalized, not onesided
            _pattern((2, 12, 7), torch.hamming_window, center=True, normalized=True, onesided=False),
            # hamming_window, not centered, not normalized, onesided
            _pattern((2, 7, 3), torch.hamming_window, center=False, normalized=False, onesided=True),
        ]
        for data_size, kwargs in patterns:
            _check_linearity(data_size, kwargs)
1549

1550
    @onlyNativeDeviceTypes
    @skipCPUIfNoFFT
    def test_batch_istft(self, device):
        """Batched istft must equal the unbatched result repeated along the batch dim."""
        unbatched = torch.tensor([
            [4., 4., 4., 4., 4.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]
        ], device=device, dtype=torch.complex64)

        # Batches of one and of four copies of the same spectrogram
        batch_sizes = (1, 4)
        batches = [unbatched.repeat(count, 1, 1) for count in batch_sizes]

        expected = torch.istft(unbatched, n_fft=4, length=4)
        results = [torch.istft(batch, n_fft=4, length=4) for batch in batches]

        for count, actual in zip(batch_sizes, results):
            self.assertEqual(expected.repeat(count, 1), actual, atol=1e-6, rtol=0, exact_dtype=True)
1568

1569
    @onlyCUDA
    @skipIf(not TEST_MKL, "Test requires MKL")
    def test_stft_window_device(self, device):
        """The (i)stft window must be on the same device as the input.

        Each of stft and istft is checked in both directions: window moved to
        ``device`` with the input left where it was created, and the reverse.
        """
        # Created without a device argument; only one operand per call is moved to `device`.
        x = torch.randn(1000, dtype=torch.complex64)
        window = torch.randn(100, dtype=torch.complex64)

        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
            torch.stft(x, n_fft=100, window=window.to(device))

        with self.assertRaisesRegex(RuntimeError, "stft input and window must be on the same device"):
            torch.stft(x.to(device), n_fft=100, window=window)

        # Valid spectrogram (input and window on the same device) to feed istft
        X = torch.stft(x, n_fft=100, window=window)

        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
            torch.istft(X, n_fft=100, window=window.to(device))

        # Use the spectrogram X here, not the time-domain signal x, so the only
        # thing wrong with this call is the device mismatch.
        with self.assertRaisesRegex(RuntimeError, "istft input and window must be on the same device"):
            torch.istft(X.to(device), n_fft=100, window=window)
1589

1590

1591
class FFTDocTestFinder:
    '''The default doctest finder doesn't like that function.__module__ doesn't
    match torch.fft. It assumes the functions are leaked imports.

    This finder instead walks ``obj.__all__`` and builds one doctest per
    routine that has a docstring.
    '''
    def __init__(self) -> None:
        self.parser = doctest.DocTestParser()

    def find(self, obj, name=None, module=None, globs=None, extraglobs=None):
        '''Return a list of DocTests for the routines listed in ``obj.__all__``.

        Args:
            obj: object exposing ``__all__``; each listed attribute that is a
                routine with a docstring yields one DocTest.
            name: accepted for signature compatibility with
                ``doctest.DocTestFinder.find``; not used.
            module: accepted for signature compatibility; not used.
            globs: globals for the examples (empty if None). The caller's
                dict is never modified.
            extraglobs: extra globals merged on top of ``globs``, taking
                precedence on key collisions.

        Returns:
            A list of ``doctest.DocTest`` objects, each named after the bare
            attribute name, in ``__all__`` order.
        '''
        # Work on a copy so merging extraglobs cannot leak into the caller's dict.
        test_globs = {} if globs is None else dict(globs)
        if extraglobs is not None:
            test_globs.update(extraglobs)

        doctests = []
        for fname in obj.__all__:
            func = getattr(obj, fname)
            if not inspect.isroutine(func):
                continue

            docstring = inspect.getdoc(func)
            if docstring is None:
                continue

            doctests.append(self.parser.get_doctest(
                docstring, globs=test_globs, name=fname, filename=None, lineno=None))

        return doctests
1617

1618

1619
class TestFFTDocExamples(TestCase):
    """Empty test case; its test methods are attached with setattr by generate_doc_test."""
1621

1622
def generate_doc_test(doc_test):
    """Attach a test method to TestFFTDocExamples that runs ``doc_test``.

    The method is registered as ``test_<doc_test.name>`` and wrapped with
    ``skipCPUIfNoFFT``.
    """
    def test(self, device):
        self.assertEqual(device, 'cpu')
        doc_runner = doctest.DocTestRunner()
        doc_runner.run(doc_test)

        if doc_runner.failures == 0:
            return

        # Print the runner's summary before failing the test.
        doc_runner.summarize()
        self.fail('Doctest failed')

    method_name = 'test_' + doc_test.name
    setattr(TestFFTDocExamples, method_name, skipCPUIfNoFFT(test))
1633

1634
# Register one doctest-backed test method per documented torch.fft routine.
for doc_test in FFTDocTestFinder().find(torch.fft, globs={'torch': torch}):
    generate_doc_test(doc_test)


# Instantiate the device-specific test classes; the doc examples are CPU-only.
instantiate_device_type_tests(TestFFT, globals())
instantiate_device_type_tests(TestFFTDocExamples, globals(), only_for='cpu')

if __name__ == '__main__':
    run_tests()
1643

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.