pytorch

Форк
0
/
test_conv.cpp 
234 строки · 6.7 Кб
1
#include <gtest/gtest.h>
2
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
3
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
4
#include <torch/csrc/jit/tensorexpr/loopnest.h>
5
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
6
#include <torch/csrc/jit/tensorexpr/tensor.h>
7
#include <torch/torch.h>
8

9
namespace torch {
10
namespace jit {
11

12
namespace te = torch::jit::tensorexpr;
13
namespace F = torch::nn::functional;
14

15
#ifdef TORCH_ENABLE_LLVM
16

17
// Generate test data with few bits of precision, to minimize error
18
// accumulation from floating-point reordering.
19
static at::Tensor genTestData(c10::IntArrayRef args) {
20
  return at::trunc(at::randn(args) * 256.0f) / 256.0f;
21
}
22

23
// Compiles a depthwise conv2d (with bias) through the LLVM codegen and checks
// the result against at::conv2d on the same data.
TEST(Conv, DepthwiseConv2D) {
  // Input shape.
  constexpr int N = 1;
  constexpr int C = 72;
  constexpr int H = 56;
  constexpr int W = 56;
  // Filter shape.
  constexpr int K = 72;
  constexpr int R = 3;
  constexpr int S = 3;
  // Convolution parameters; the group count equals the channel count.
  constexpr int kPad = 1;
  constexpr int kStride = 2;
  constexpr int kGroups = C;
  constexpr int CperG = C / kGroups;

  // Declare the kernel's buffers and the depthwise convolution over them.
  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::BufHandle bias("bias", {K}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups);

  // Simplify the loop nest and compile it with LLVM.
  te::LoopNest nest({output});
  nest.simplify();
  nest.prepareForCodegen();
  std::vector<te::CodeGen::BufferArg> kernel_params = {
      input, weight, bias, output};
  te::LLVMCodeGen cg(nest.root_stmt(), kernel_params);

  // Reference result computed by ATen.
  at::Tensor input_data = genTestData({N, C, H, W});
  at::Tensor weight_data = genTestData({K, CperG, R, S});
  at::Tensor bias_data = genTestData({K});
  at::Tensor expected = at::conv2d(
      input_data,
      weight_data,
      bias_data,
      kStride,
      kPad,
      /*dilation=*/1,
      kGroups);

  // Run the compiled kernel; arguments follow the order of kernel_params.
  at::Tensor actual = at::zeros_like(expected);
  std::vector<te::CodeGen::CallArg> kernel_args = {
      input_data.data_ptr<float>(),
      weight_data.data_ptr<float>(),
      bias_data.data_ptr<float>(),
      actual.data_ptr<float>()};
  cg.call(kernel_args);

  ASSERT_TRUE(at::allclose(expected, actual));
}
53

54
// Same as the biased depthwise test, but the kernel is built without a bias
// buffer and the reference is computed with an undefined bias tensor.
TEST(Conv, DepthwiseConv2DNoBias) {
  // Input shape.
  constexpr int N = 1;
  constexpr int C = 72;
  constexpr int H = 56;
  constexpr int W = 56;
  // Filter shape.
  constexpr int K = 72;
  constexpr int R = 3;
  constexpr int S = 3;
  // Convolution parameters; the group count equals the channel count.
  constexpr int kPad = 1;
  constexpr int kStride = 2;
  constexpr int kGroups = C;
  constexpr int CperG = C / kGroups;

  // Declare the kernel's buffers and the bias-free depthwise convolution.
  te::BufHandle input("input", {N, C, H, W}, te::kFloat);
  te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);
  te::Tensor output =
      te::conv2d_depthwise(input, weight, kStride, kPad, kGroups);

  // Simplify the loop nest and compile it with LLVM.
  te::LoopNest nest({output});
  nest.simplify();
  nest.prepareForCodegen();
  std::vector<te::CodeGen::BufferArg> kernel_params = {input, weight, output};
  te::LLVMCodeGen cg(nest.root_stmt(), kernel_params);

  // Reference result computed by ATen.
  at::Tensor input_data = genTestData({N, C, H, W});
  at::Tensor weight_data = genTestData({K, CperG, R, S});
  at::Tensor expected = at::conv2d(
      input_data,
      weight_data,
      at::Tensor(),
      kStride,
      kPad,
      /*dilation=*/1,
      kGroups);

  // Run the compiled kernel; arguments follow the order of kernel_params.
  at::Tensor actual = at::zeros_like(expected);
  std::vector<te::CodeGen::CallArg> kernel_args = {
      input_data.data_ptr<float>(),
      weight_data.data_ptr<float>(),
      actual.data_ptr<float>()};
  cg.call(kernel_args);

  ASSERT_TRUE(at::allclose(expected, actual));
}
79

80
// Builds the depthwise conv2d kernel with every dimension and convolution
// parameter expressed as an integer variable, then runs it with concrete
// values and compares against at::conv2d.
TEST(Conv, DepthwiseConv2DDynamicShapes) {
  // Symbolic input dimensions.
  te::VarHandle N_var("N", te::kInt);
  te::VarHandle C_var("C", te::kInt);
  te::VarHandle H_var("H", te::kInt);
  te::VarHandle W_var("W", te::kInt);
  // Symbolic filter dimensions.
  te::VarHandle K_var("K", te::kInt);
  te::VarHandle CperG_var("CperG", te::kInt);
  te::VarHandle R_var("R", te::kInt);
  te::VarHandle S_var("S", te::kInt);
  // Symbolic convolution parameters.
  te::VarHandle kPad_var("kPad", te::kInt);
  te::VarHandle kStride_var("kStride", te::kInt);
  te::VarHandle kGroups_var("kGroups", te::kInt);

  te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat);
  te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat);
  te::Tensor output = te::conv2d_depthwise(
      input,
      weight,
      N_var,
      C_var,
      H_var,
      W_var,
      K_var,
      CperG_var,
      R_var,
      S_var,
      kStride_var,
      kPad_var,
      kGroups_var);

  // Simplify the loop nest and compile it with LLVM. The variables are passed
  // as kernel parameters between the input buffers and the output.
  te::LoopNest nest({output});
  nest.simplify();
  nest.prepareForCodegen();
  te::LLVMCodeGen cg(
      nest.root_stmt(),
      {input,
       weight,
       N_var,
       C_var,
       H_var,
       W_var,
       K_var,
       CperG_var,
       R_var,
       S_var,
       kPad_var,
       kStride_var,
       kGroups_var,
       output});

  // Concrete values supplied for the variables at call time.
  constexpr int N = 1;
  constexpr int C = 72;
  constexpr int H = 56;
  constexpr int W = 56;
  constexpr int K = 72;
  constexpr int R = 3;
  constexpr int S = 3;
  constexpr int kPad = 1;
  constexpr int kStride = 2;
  constexpr int kGroups = C;
  constexpr int CperG = C / kGroups;

  // Reference result computed by ATen.
  at::Tensor input_data = genTestData({N, C, H, W});
  at::Tensor weight_data = genTestData({K, CperG, R, S});
  at::Tensor expected = at::conv2d(
      input_data,
      weight_data,
      at::Tensor(),
      kStride,
      kPad,
      /*dilation=*/1,
      kGroups);

  // Run the compiled kernel; arguments are listed in the same order as the
  // kernel parameters above (note kPad precedes kStride in both lists).
  at::Tensor actual = at::zeros_like(expected);
  cg.call(
      {input_data.data_ptr<float>(),
       weight_data.data_ptr<float>(),
       N,
       C,
       H,
       W,
       K,
       CperG,
       R,
       S,
       kPad,
       kStride,
       kGroups,
       actual.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(expected, actual));
}
159

160
#endif
161

162
// Expresses a conv2d as a sum-reduction in tensor-expression IR, evaluates it
// with SimpleIREvaluator, and compares the result against F::conv2d.
TEST(Conv, Conv2D) {
  // Input dimensions.
  constexpr int N = 1;
  constexpr int C = 3;
  constexpr int H = 11;
  constexpr int W = 11;

  // Filter dimensions.
  constexpr int K = 8;
  constexpr int R = 3;
  constexpr int S = 3;

  // Output dims.
  constexpr int OH = H - R + 1;
  constexpr int OW = W - S + 1;

  // Compute reference result.
  at::Tensor input = torch::randn({N, C, H, W});
  at::Tensor filter = torch::randn({K, C, R, S});
  at::Tensor ref = F::conv2d(input, filter);

  // Double check the output size is as expected.
  ASSERT_EQ(ref.size(0), N);
  ASSERT_EQ(ref.size(1), K);
  ASSERT_EQ(ref.size(2), OH);
  ASSERT_EQ(ref.size(3), OW);

  te::BufHandle inputB("input", {N, C, H, W}, te::kFloat);
  te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat);

  te::Tensor conv = te::Reduce(
      "conv",
      {N, K, OH, OW},
      te::Sum(),
      // FIXME: We have to use a `std::vector` parameter here and then unpack
      // it, because we don't have an overload allowing for an arbitrary number
      // of ExprHandle/VarHandle parameters.
      [&](const std::vector<te::VarHandle>& idx) {
        // The first four indices address the output element; the remaining
        // three are the reduction indices.
        const te::VarHandle& batch = idx[0];
        const te::VarHandle& out_ch = idx[1];
        const te::VarHandle& out_row = idx[2];
        const te::VarHandle& out_col = idx[3];
        const te::VarHandle& in_ch = idx[4];
        const te::VarHandle& f_row = idx[5];
        const te::VarHandle& f_col = idx[6];
        // Product of the input element under the filter window and the
        // matching filter element.
        return inputB.load(batch, in_ch, out_row + f_row, out_col + f_col) *
            filterB.load(out_ch, in_ch, f_row, f_col);
      },
      // FIXME: If you forget one of the reduction dims, you get a segfault.
      // Could that be caught by a verifier?
      {C, R, S});

  // FIXME: It'd be nice to have a single header that pulls in things like
  // LoopNest, IRSimplifier, etc.
  te::LoopNest nest({conv});
  nest.prepareForCodegen();
  te::StmtPtr stmt = te::IRSimplifier::simplify(nest.root_stmt());

  // Evaluate the simplified statement into a fresh output tensor.
  at::Tensor result = at::empty_like(ref);
  te::SimpleIREvaluator cg(stmt, {inputB, filterB, conv});
  cg.call(
      {input.data_ptr<float>(),
       filter.data_ptr<float>(),
       result.data_ptr<float>()});

  ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3));
}
232

233
} // namespace jit
234
} // namespace torch
235

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.