pytorch

Форк
0
/
fft.cpp 
130 строк · 4.3 Кб
1
#include <gtest/gtest.h>
2

3
#include <c10/util/irange.h>
4
#include <test/cpp/api/support.h>
5
#include <torch/torch.h>
6

7
// Naive DFT of a 1 dimensional tensor
8
torch::Tensor naive_dft(torch::Tensor x, bool forward = true) {
9
  TORCH_INTERNAL_ASSERT(x.dim() == 1);
10
  x = x.contiguous();
11
  auto out_tensor = torch::zeros_like(x);
12
  const int64_t len = x.size(0);
13

14
  // Roots of unity, exp(-2*pi*j*n/N) for n in [0, N), reversed for inverse
15
  // transform
16
  std::vector<c10::complex<double>> roots(len);
17
  const auto angle_base = (forward ? -2.0 : 2.0) * M_PI / len;
18
  for (const auto i : c10::irange(len)) {
19
    auto angle = i * angle_base;
20
    roots[i] = c10::complex<double>(std::cos(angle), std::sin(angle));
21
  }
22

23
  const auto in = x.data_ptr<c10::complex<double>>();
24
  const auto out = out_tensor.data_ptr<c10::complex<double>>();
25
  for (const auto i : c10::irange(len)) {
26
    for (const auto j : c10::irange(len)) {
27
      out[i] += roots[(j * i) % len] * in[j];
28
    }
29
  }
30
  return out_tensor;
31
}
32

33
// NOTE: Visual Studio and ROCm builds don't understand complex literals
34
//   as of August 2020
35

36
TEST(FFTTest, fft) {
37
  auto t = torch::randn(128, torch::kComplexDouble);
38
  auto actual = torch::fft::fft(t);
39
  auto expect = naive_dft(t);
40
  ASSERT_TRUE(torch::allclose(actual, expect));
41
}
42

43
TEST(FFTTest, fft_real) {
44
  auto t = torch::randn(128, torch::kDouble);
45
  auto actual = torch::fft::fft(t);
46
  auto expect = torch::fft::fft(t.to(torch::kComplexDouble));
47
  ASSERT_TRUE(torch::allclose(actual, expect));
48
}
49

50
TEST(FFTTest, fft_pad) {
51
  auto t = torch::randn(128, torch::kComplexDouble);
52
  auto actual = torch::fft::fft(t, 200);
53
  auto expect = torch::fft::fft(torch::constant_pad_nd(t, {0, 72}));
54
  ASSERT_TRUE(torch::allclose(actual, expect));
55

56
  actual = torch::fft::fft(t, 64);
57
  expect = torch::fft::fft(torch::constant_pad_nd(t, {0, -64}));
58
  ASSERT_TRUE(torch::allclose(actual, expect));
59
}
60

61
TEST(FFTTest, fft_norm) {
62
  auto t = torch::randn(128, torch::kComplexDouble);
63
  // NOLINTNEXTLINE(bugprone-argument-comment)
64
  auto unnorm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/{});
65
  // NOLINTNEXTLINE(bugprone-argument-comment)
66
  auto norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"forward");
67
  ASSERT_TRUE(torch::allclose(unnorm / 128, norm));
68

69
  // NOLINTNEXTLINE(bugprone-argument-comment)
70
  auto ortho_norm = torch::fft::fft(t, /*n=*/{}, /*axis=*/-1, /*norm=*/"ortho");
71
  ASSERT_TRUE(torch::allclose(unnorm / std::sqrt(128), ortho_norm));
72
}
73

74
TEST(FFTTest, ifft) {
75
  auto T = torch::randn(128, torch::kComplexDouble);
76
  auto actual = torch::fft::ifft(T);
77
  auto expect = naive_dft(T, /*forward=*/false) / 128;
78
  ASSERT_TRUE(torch::allclose(actual, expect));
79
}
80

81
TEST(FFTTest, fft_ifft) {
82
  auto t = torch::randn(77, torch::kComplexDouble);
83
  auto T = torch::fft::fft(t);
84
  ASSERT_EQ(T.size(0), 77);
85
  ASSERT_EQ(T.scalar_type(), torch::kComplexDouble);
86

87
  auto t_round_trip = torch::fft::ifft(T);
88
  ASSERT_EQ(t_round_trip.size(0), 77);
89
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble);
90
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
91
}
92

93
TEST(FFTTest, rfft) {
94
  auto t = torch::randn(129, torch::kDouble);
95
  auto actual = torch::fft::rfft(t);
96
  auto expect = torch::fft::fft(t.to(torch::kComplexDouble)).slice(0, 0, 65);
97
  ASSERT_TRUE(torch::allclose(actual, expect));
98
}
99

100
TEST(FFTTest, rfft_irfft) {
101
  auto t = torch::randn(128, torch::kDouble);
102
  auto T = torch::fft::rfft(t);
103
  ASSERT_EQ(T.size(0), 65);
104
  ASSERT_EQ(T.scalar_type(), torch::kComplexDouble);
105

106
  auto t_round_trip = torch::fft::irfft(T);
107
  ASSERT_EQ(t_round_trip.size(0), 128);
108
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kDouble);
109
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
110
}
111

112
TEST(FFTTest, ihfft) {
113
  auto T = torch::randn(129, torch::kDouble);
114
  auto actual = torch::fft::ihfft(T);
115
  auto expect = torch::fft::ifft(T.to(torch::kComplexDouble)).slice(0, 0, 65);
116
  ASSERT_TRUE(torch::allclose(actual, expect));
117
}
118

119
TEST(FFTTest, hfft_ihfft) {
120
  auto t = torch::randn(64, torch::kComplexDouble);
121
  t[0] = .5; // Must be purely real to satisfy hermitian symmetry
122
  auto T = torch::fft::hfft(t, 127);
123
  ASSERT_EQ(T.size(0), 127);
124
  ASSERT_EQ(T.scalar_type(), torch::kDouble);
125

126
  auto t_round_trip = torch::fft::ihfft(T);
127
  ASSERT_EQ(t_round_trip.size(0), 64);
128
  ASSERT_EQ(t_round_trip.scalar_type(), torch::kComplexDouble);
129
  ASSERT_TRUE(torch::allclose(t, t_round_trip));
130
}
131

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.