3
from __future__ import annotations
9
from . import _dtypes_impl, _util
10
from ._normalizations import ArrayLike, normalizer
14
"""NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
16
@functools.wraps(func)
17
def wrapped(tensor, *args, **kwds):
19
_dtypes_impl.default_dtypes().complex_dtype
20
if tensor.is_complex()
21
else _dtypes_impl.default_dtypes().float_dtype
23
tensor = _util.cast_if_needed(tensor, target_dtype)
24
return func(tensor, *args, **kwds)
31
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
32
return torch.fft.fft(a, n, dim=axis, norm=norm)
37
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
38
return torch.fft.ifft(a, n, dim=axis, norm=norm)
43
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
44
return torch.fft.rfft(a, n, dim=axis, norm=norm)
49
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
50
return torch.fft.irfft(a, n, dim=axis, norm=norm)
55
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
56
return torch.fft.fftn(a, s, dim=axes, norm=norm)
61
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
62
return torch.fft.ifftn(a, s, dim=axes, norm=norm)
67
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
68
return torch.fft.rfftn(a, s, dim=axes, norm=norm)
73
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
74
return torch.fft.irfftn(a, s, dim=axes, norm=norm)
79
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
80
return torch.fft.fft2(a, s, dim=axes, norm=norm)
85
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
86
return torch.fft.ifft2(a, s, dim=axes, norm=norm)
91
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
92
return torch.fft.rfft2(a, s, dim=axes, norm=norm)
97
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
98
return torch.fft.irfft2(a, s, dim=axes, norm=norm)
103
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
104
return torch.fft.hfft(a, n, dim=axis, norm=norm)
109
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
110
return torch.fft.ihfft(a, n, dim=axis, norm=norm)
114
def fftfreq(n, d=1.0):
115
return torch.fft.fftfreq(n, d)
119
def rfftfreq(n, d=1.0):
120
return torch.fft.rfftfreq(n, d)
124
def fftshift(x: ArrayLike, axes=None):
125
return torch.fft.fftshift(x, axes)
129
def ifftshift(x: ArrayLike, axes=None):
130
return torch.fft.ifftshift(x, axes)