pytorch

Форк
0
/
test_te_fuser_pass.cpp 
402 строки · 13.2 Кб
1
#include <gtest/gtest.h>
2

3
#include <test/cpp/tensorexpr/test_base.h>
4
#include <torch/csrc/jit/codegen/fuser/interface.h>
5
#include <torch/csrc/jit/ir/ir.h>
6
#include <torch/csrc/jit/ir/irparser.h>
7
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
8
#include <torch/csrc/jit/runtime/interpreter.h>
9
#include <torch/csrc/jit/testing/file_check.h>
10
#include <sstream>
11

12
namespace torch {
13
namespace jit {
14

15
using namespace torch::jit::tensorexpr;
16

17
struct WithCPUFuser {
18
  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
19
    overrideCanFuseOnCPU(val);
20
  }
21

22
  ~WithCPUFuser() {
23
    overrideCanFuseOnCPU(cpuFuserEnabled);
24
  }
25

26
  bool cpuFuserEnabled;
27
};
28

29
// An in-place aten::add_ sits in the middle of a chain of fusible ops.
TEST(TEFuserPass, FuserPass_1) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
      %3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
      %4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
      %5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
      return (%5))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph);

  // Fusion must not cross the in-place operation: the checks look for a
  // fusion group, then the add_, then another fusion group.
  testing::FileCheck checker;
  checker.check("prim::TensorExprGroup_")
      ->check("aten::add_")
      ->check("prim::TensorExprGroup_")
      ->run(*graph);
}
54

55
// An in-place aten::add_ consumes the result of an aten::add and feeds a mul.
TEST(TEFuserPass, FuserPass_2) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%0 : Float(128, strides=[1], device=cpu),
          %1 : Float(128, strides=[1], device=cpu)):
      %12 : int = prim::Constant[value=1]()
      %a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
      %b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
      %c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
      %d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
      return (%d))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph);

  // Fusion must not cross the in-place operation: the add_ is expected to
  // appear ahead of the fusion group.
  testing::FileCheck checker;
  checker.check("aten::add_")->check("prim::TensorExprGroup_0")->run(*graph);
}
78

79
// Checks how the min_group_size argument affects a graph with a single mul.
TEST(TEFuserPass, FuserPass_3) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Float(128, strides=[1], device=cpu),
          %y : Float(128, strides=[1], device=cpu)):
      %r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR";

  // Builds a fresh graph from `ir_text`, lints it and runs the fuser on it
  // with the requested minimum group size.
  const auto fuse_fresh_graph = [&ir_text](size_t min_group_size) {
    std::shared_ptr<Graph> graph = std::make_shared<Graph>();
    torch::jit::parseIR(ir_text, graph.get());
    graph->lint();
    FuseTensorExprs(graph, min_group_size);
    return graph;
  };

  {
    const std::shared_ptr<Graph> graph = fuse_fresh_graph(2);

    // No fusion group should be created: a single node is below the minimum
    // group size of 2.
    testing::FileCheck checker;
    checker.check_not("prim::TensorExprGroup")->run(*graph);
  }
  {
    const std::shared_ptr<Graph> graph = fuse_fresh_graph(1);

    // A fusion group should be created: a single node meets the minimum
    // group size of 1.
    testing::FileCheck checker;
    checker.check("prim::TensorExprGroup")->run(*graph);
  }
}
107

108
// Inputs and intermediates are declared as Float tensors without sizes.
TEST(TEFuserPass, FuserPass_0DimInput) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Float(device=cpu),
          %y : Float(device=cpu)):
      %one : int = prim::Constant[value=1]()
      %a : Float(device=cpu) = aten::mul(%x, %y)
      %b : Float(device=cpu) = aten::add(%x, %a, %one)
      return (%b))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph);

  // 0-dim tensors are expected to be fused as well.
  testing::FileCheck checker;
  checker.check("prim::TensorExprGroup")->run(*graph);
}
126

127
// Runs the fuser on a CPU-only graph while the CPU fuser override is false.
TEST(TEFuserPass, FuserPass_UnfusibleDevice) {
  WithCPUFuser cpu_fuser_guard(false);
  const char* const ir_text = R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(10, strides=[1], device=cpu)):
      %a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%a))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph, /* min_group_size= */ 1);

  // Fusion groups must not be started from nodes on an unfusible device.
  testing::FileCheck checker;
  checker.check_not("prim::TensorExprGroup")->run(*graph);
}
143

144
// All values in the graph are plain `Tensor`, i.e. carry no shape info.
TEST(TEFuserPass, FuserPass_UnknownShapes) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Tensor,
          %y : Tensor):
      %a : Tensor = aten::mul(%x, %y)
      %b : Tensor = aten::mul(%x, %a)
      return (%b))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph);

  // No fusion groups may be generated when shapes are not known.
  testing::FileCheck checker;
  checker.check_not("prim::TensorExprGroup")->run(*graph);
}
161

162
// Exercises the fuser on graphs whose values live on one or several devices.
TEST(TEFuserPass, FuserPass_Multidevice) {
  // Runs one scenario: with the CPU fuser override enabled, parses `ir_text`,
  // lints the graph, runs the fuser with `min_group_size`, and then checks
  // that "prim::TensorExprGroup" is present (should_fuse == true) or absent
  // (should_fuse == false) in the resulting graph.
  const auto run_case =
      [](const char* ir_text, size_t min_group_size, bool should_fuse) {
        WithCPUFuser cpu_fuser_guard;
        std::shared_ptr<Graph> graph = std::make_shared<Graph>();
        torch::jit::parseIR(ir_text, graph.get());
        graph->lint();

        FuseTensorExprs(graph, min_group_size);

        testing::FileCheck checker;
        if (should_fuse) {
          checker.check("prim::TensorExprGroup")->run(*graph);
        } else {
          checker.check_not("prim::TensorExprGroup")->run(*graph);
        }
      };

  // We should be able to fuse this
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      return (%cat))IR",
      /* min_group_size= */ 1,
      /* should_fuse= */ true);

  // We should not fuse this aten::cat since its inputs are from different
  // devices
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cuda:0),
          %z : Float(30, strides=[1], device=cpu)):
      %dim : int = prim::Constant[value=0]()
      %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
      return (%cat))IR",
      /* min_group_size= */ 1,
      /* should_fuse= */ false);

  // Test that we check device before merging one node (cat) into another
  // (mul)
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(10, strides=[1], device=cuda:0)):
      %dim : int = prim::Constant[value=0]()
      %xy_list : Tensor[] = prim::ListConstruct(%x, %y)
      %xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
      %r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
      return (%r))IR",
      /* min_group_size= */ 2,
      /* should_fuse= */ false);

  // Test that we check device before merging one node (mul) into another
  // (cat)
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cpu),
          %z : Float(10, strides=[1], device=cuda:0)):
      %z2 : Tensor = aten::mul(%z, %z)
      %dim : int = prim::Constant[value=0]()
      %xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
      %cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
      return (%cat))IR",
      /* min_group_size= */ 2,
      /* should_fuse= */ false);

  // We should not fuse this graph since its inputs are from different devices
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cpu),
          %y : Float(20, strides=[1], device=cuda:0)):
      %r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
      return (%r))IR",
      /* min_group_size= */ 1,
      /* should_fuse= */ false);

  // We should not fuse these two computations since they use different
  // devices
  run_case(
      R"IR(
    graph(%x : Float(10, strides=[1], device=cuda:0),
          %y : Float(20, strides=[1], device=cuda:1),
          %z : Float(20, strides=[1], device=cpu)):
      %x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
      %y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
      %z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
      return (%x2, %y2, %z2))IR",
      /* min_group_size= */ 2,
      /* should_fuse= */ false);
}
281

282
// Two muls that share no inputs or outputs, both returned from the graph.
TEST(TEFuserPass, FuserPass_MergeGroups) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%a : Float(128, strides=[1], device=cpu),
          %b : Float(128, strides=[1], device=cpu)):
      %x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
      %y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
      return (%x, %y))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph, /* min_group_size= */ 1);

  // Although the %x and %y computations are completely independent, they
  // should land in one fusion group rather than two: after the first group
  // is matched, no second one may follow.
  testing::FileCheck checker;
  checker.check("= prim::TensorExprGroup_")
      ->check_not("= prim::TensorExprGroup_")
      ->run(*graph);
}
303

304
// The final __or__ produces a plain `Tensor` (no shape info), while the
// preceding __and__ is fully typed.
TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Bool(8, strides=[1], device=cpu),
          %y : Bool(8, strides=[1], device=cpu)):
      %a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
      %b : Tensor = aten::__or__(%a, %y)
      return (%b)
    )IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph, /* min_group_size= */ 2);

  // No fusion group is expected in the resulting graph.
  testing::FileCheck checker;
  checker.check_not("prim::TensorExprGroup")->run(*graph);
}
319

320
// aten::eq feeding the three-argument form of aten::where.
TEST(TEFuserPass, FuserPass_Where) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
      return (%b)
    )IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph, /* min_group_size= */ 2);

  // A fusion group is expected in the resulting graph.
  testing::FileCheck checker;
  checker.check("prim::TensorExprGroup")->run(*graph);
}
336

337
// aten::eq feeding the single-argument form of aten::where, which is typed
// here as returning a Tensor[].
TEST(TEFuserPass, FuserPass_WhereList) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%x : Float(8, strides=[1], device=cpu),
          %y : Float(8, strides=[1], device=cpu),
          %z : Float(8, strides=[1], device=cpu)):
      %cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
      %b : Tensor[] = aten::where(%cond)
      return (%b)
    )IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(graph, /* min_group_size= */ 2);

  // No fusion group is expected in the resulting graph.
  testing::FileCheck checker;
  checker.check_not("prim::TensorExprGroup")->run(*graph);
}
353

354
// Fuses two chained muls with fuse_to_dynamic_shapes enabled, then runs the
// resulting graph on inputs of several sizes and compares against eager mode.
TEST(TEFuserPass, DynamicShapeFusion) {
  WithCPUFuser cpu_fuser_guard;
  const char* const ir_text = R"IR(
    graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
          %1 : Float(10, 5, strides=[5, 1], device=cpu)):
      %2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
      %3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
      return (%3))IR";

  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(ir_text, graph.get());
  graph->lint();

  FuseTensorExprs(
      graph,
      /* min_group_size = */ 2,
      /* add_composed_op = */ true,
      /* fuse_to_dynamic_shapes = */ true);
  Code code(graph, "");

  // The fused graph is expected to mention, in this order, a dynamic group,
  // a dynamic guard and a regular fusion group.
  testing::FileCheck checker;
  checker.check("prim::TensorExprDynamicGroup_")
      ->check("prim::TensorExprDynamicGuard")
      ->check("prim::TensorExprGroup_")
      ->run(*graph);

  // Runs `code` through the interpreter on the given pair of tensors and
  // compares the result with the same computation done via at::mul.
  const auto verify_against_eager =
      [&code](const std::vector<at::Tensor>& inputs) {
        TORCH_INTERNAL_ASSERT(inputs.size() == 2);

        const auto first_product = at::mul(inputs[0], inputs[1]);
        const auto expected = at::mul(first_product, inputs[1]);

        InterpreterState interpreter(code);
        Stack stack(inputs.begin(), inputs.end());
        interpreter.run(stack);
        const at::Tensor actual = pop(stack).toTensor();
        ASSERT_TRUE(at::allclose(actual, expected));
      };

  // Same sizes as declared in the IR.
  verify_against_eager({at::rand({10, 5}), at::rand({10, 5})});

  // Different first dimension.
  verify_against_eager({at::rand({20, 5}), at::rand({20, 5})});

  // Both dimensions differ from the IR.
  verify_against_eager({at::rand({25, 60}), at::rand({25, 60})});
}
400

401
} // namespace jit
402
} // namespace torch
403

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.