pytorch

Форк
0
130 строк · 2.7 Кб
1
# mypy: ignore-errors
2

3
from __future__ import annotations
4

5
import functools
6

7
import torch
8

9
from . import _dtypes_impl, _util
10
from ._normalizations import ArrayLike, normalizer
11

12

13
def upcast(func):
14
    """NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
15

16
    @functools.wraps(func)
17
    def wrapped(tensor, *args, **kwds):
18
        target_dtype = (
19
            _dtypes_impl.default_dtypes().complex_dtype
20
            if tensor.is_complex()
21
            else _dtypes_impl.default_dtypes().float_dtype
22
        )
23
        tensor = _util.cast_if_needed(tensor, target_dtype)
24
        return func(tensor, *args, **kwds)
25

26
    return wrapped
27

28

29
@normalizer
30
@upcast
31
def fft(a: ArrayLike, n=None, axis=-1, norm=None):
32
    return torch.fft.fft(a, n, dim=axis, norm=norm)
33

34

35
@normalizer
36
@upcast
37
def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
38
    return torch.fft.ifft(a, n, dim=axis, norm=norm)
39

40

41
@normalizer
42
@upcast
43
def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
44
    return torch.fft.rfft(a, n, dim=axis, norm=norm)
45

46

47
@normalizer
48
@upcast
49
def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
50
    return torch.fft.irfft(a, n, dim=axis, norm=norm)
51

52

53
@normalizer
54
@upcast
55
def fftn(a: ArrayLike, s=None, axes=None, norm=None):
56
    return torch.fft.fftn(a, s, dim=axes, norm=norm)
57

58

59
@normalizer
60
@upcast
61
def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
62
    return torch.fft.ifftn(a, s, dim=axes, norm=norm)
63

64

65
@normalizer
66
@upcast
67
def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
68
    return torch.fft.rfftn(a, s, dim=axes, norm=norm)
69

70

71
@normalizer
72
@upcast
73
def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
74
    return torch.fft.irfftn(a, s, dim=axes, norm=norm)
75

76

77
@normalizer
78
@upcast
79
def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
80
    return torch.fft.fft2(a, s, dim=axes, norm=norm)
81

82

83
@normalizer
84
@upcast
85
def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
86
    return torch.fft.ifft2(a, s, dim=axes, norm=norm)
87

88

89
@normalizer
90
@upcast
91
def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
92
    return torch.fft.rfft2(a, s, dim=axes, norm=norm)
93

94

95
@normalizer
96
@upcast
97
def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
98
    return torch.fft.irfft2(a, s, dim=axes, norm=norm)
99

100

101
@normalizer
102
@upcast
103
def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
104
    return torch.fft.hfft(a, n, dim=axis, norm=norm)
105

106

107
@normalizer
108
@upcast
109
def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
110
    return torch.fft.ihfft(a, n, dim=axis, norm=norm)
111

112

113
@normalizer
114
def fftfreq(n, d=1.0):
115
    return torch.fft.fftfreq(n, d)
116

117

118
@normalizer
119
def rfftfreq(n, d=1.0):
120
    return torch.fft.rfftfreq(n, d)
121

122

123
@normalizer
124
def fftshift(x: ArrayLike, axes=None):
125
    return torch.fft.fftshift(x, axes)
126

127

128
@normalizer
129
def ifftshift(x: ArrayLike, axes=None):
130
    return torch.fft.ifftshift(x, axes)
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.