// pytorch (fork): test_kernel.cpp — 2133 lines, 75.8 KB
#include <gtest/gtest.h>

#include <ATen/code_template.h>
#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>

#include <cmath>
#include <functional>
#include <numeric>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace torch {
20
namespace jit {
21

22
using namespace torch::indexing;
23
using namespace torch::jit::tensorexpr;
24

25
class Kernel : public ::testing::Test {
26
 public:
27
  void SetUp() override {
28
    getTEMustUseLLVMOnCPU() = false;
29
  }
30
};
31

32
// With LLVM codegen available, checks that the statement generated for a
// pointwise aten::mul feeding aten::matmul contains a 5000-iteration loop
// annotated as parallel. Without LLVM the test is reported as skipped instead
// of silently passing without checking anything.
TEST_F(Kernel, ParallelExternalCallBuf) {
#ifndef TORCH_ENABLE_LLVM
  GTEST_SKIP() << "requires LLVM codegen (TORCH_ENABLE_LLVM)";
#else
  // NOTE(review): %4 is declared as Float(1000, 5000), but
  // (1000x5000) @ (5000x1000) has shape 1000x1000. The kernel is never run
  // here; only its generated statement is inspected. Confirm before reusing
  // this graph.
  const auto graph_string = R"IR(
    graph(%0 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %1 : Float(1000, 5000, strides=[5000, 1], device=cpu),
          %2 : Float(5000, 1000, strides=[5000, 1], device=cpu)):
      %3 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::mul(%0, %1)
      %4 : Float(1000, 5000, strides=[5000, 1], device=cpu) = aten::matmul(%3, %2)
      return (%4))IR";
  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, &*graph);

  const std::string verification_pattern =
      R"IR(
# CHECK: for (int64_t i = 0ll; i < 5000ll; i++)  /* parallel */{)IR";

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
#endif
}

// Checks that the names aten_mul / aten_sub do not appear in the printed
// statement, i.e. those intermediates were inlined or otherwise eliminated.
// The second case runs on CPU and, when available, on CUDA.
TEST_F(Kernel, InliningIntermediates) {
  // Here each mul has only one use, so it should be completely inlined.
  {
    const auto graph_string = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
              %1 : Float(5, 3, strides=[3, 1], device=cpu)):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
          %5: Float(5, 3, strides=[3, 1]) = aten::add(%4, %1, %one)
          return (%5))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);
    auto stmt = k.getCodeGenStmt();
    std::ostringstream oss;
    oss << *stmt;
    torch::jit::testing::FileCheck().check_not("aten_mul")->run(oss.str());
  }
  {
    const auto graph_template = R"IR(
        graph(%0 : Float(5, 3, strides=[3, 1], device=${device}),
              %1 : Float(5, 3, strides=[3, 1], device=${device})):
          %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
          %one : int = prim::Constant[value=1]()
          %3 : Float(5, 3, strides=[3, 1]) = aten::sub(%0, %2, %one)
          %4 : Float(5, 3, strides=[3, 1]) = aten::add(%3, %0, %one)
          %5 : Float(5, 3, strides=[3, 1]) = aten::div(%3, %0)
          return (%4, %5))IR";
    for (bool use_cuda : {false, true}) {
      // Only query CUDA availability for the CUDA iteration.
      if (use_cuda && !torch::cuda::is_available()) {
        continue;
      }

      at::jit::TemplateEnv env;
      env.s("device", use_cuda ? "cuda:0" : "cpu");
      const auto graph_string = format(graph_template, env);
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      auto stmt = k.getCodeGenStmt();
      std::ostringstream oss;
      oss << *stmt;
      const std::string ir = oss.str();

      // aten_mul only has one use, inlined completely
      torch::jit::testing::FileCheck().check_not("aten_mul")->run(ir);

      // aten_sub should be removed by the CUDA backend by metavar rewriting
      // and by the CPU backend by horizontal fusion.
      torch::jit::testing::FileCheck().check_not("aten_sub")->run(ir);
    }
  }
}

// Builds the kernel with the fourth constructor argument set to true. Checks
// that exactly one constant descriptor is produced and that the printed
// statement contains neither "Alloc" nor "Free". Then checks the numerical
// result against ATen.
TEST_F(Kernel, PreAllocIntermediateBufs) {
  const auto graph_string = R"IR(
graph(%a.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu),
      %b.1 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu)):
  %2 : int = prim::Constant[value=1]()
  %c.2 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::matmul(%a.1, %b.1) # test_matmul.py:12:12
  %3 : Float(8, 8, strides=[8, 1], requires_grad=0, device=cpu) = aten::add(%a.1, %c.2, %2) # test_matmul.py:13:15
  return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::matmul(a, b) + a;
  // NOTE(review): the trailing `true` is presumably the constructor's
  // pre-allocation flag (see the constants check below); confirm in kernel.h.
  TensorExprKernel k(graph, {}, {}, true);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;
  const std::string ir = oss.str();

  // Check whether the intermediate buffer has been added to constants.
  // (Unsigned literal: size() is unsigned.)
  const auto& constants = k.getConstantDescriptors();
  ASSERT_EQ(constants.size(), 1u);

  // Check the IR we produced
  torch::jit::testing::FileCheck().check_not("Alloc")->run(ir);
  torch::jit::testing::FileCheck().check_not("Free")->run(ir);

  // Check correctness
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}

// Pointwise a * (a * b) on two contiguous 5x3 inputs. The generated statement
// must contain exactly two consecutive `for` loops (no third one), and the
// result must match ATen exactly.
TEST_F(Kernel, _1) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  const float* out_data = o.data_ptr<float>();
  const float* ref_data = ref.data_ptr<float>();
  for (const auto i : c10::irange(ref.numel())) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

// Same as _1, but the second input is a transposed view (strides [1, 5]). The
// generated statement must still contain exactly two consecutive `for` loops,
// and the result must match ATen exactly.
TEST_F(Kernel, _2) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  const float* out_data = o.data_ptr<float>();
  const float* ref_data = ref.data_ptr<float>();
  for (const auto i : c10::irange(ref.numel())) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

// Same as _1, but the second input takes every other row and column of a 10x6
// tensor (strides [12, 2]). The generated statement must still contain exactly
// two consecutive `for` loops, and the result must match ATen exactly.
TEST_F(Kernel, _3) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2), Slice(None, None, 2)});
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  const float* out_data = o.data_ptr<float>();
  const float* ref_data = ref.data_ptr<float>();
  for (const auto i : c10::irange(ref.numel())) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

// Compiles (but never runs) a kernel over a 4,000,000,000-element input and
// checks that "00000000ll;" appears in the printed statement.
TEST_F(Kernel, Huge) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(4000000000, strides=[1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=0]()
        %2 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::unsqueeze(%x.1, %1)
        %3 : Float(1, 4000000000, strides=[4000000000, 1], requires_grad=0, device=cpu) = aten::relu(%2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  std::ostringstream oss;
  oss << *k.getCodeGenStmt();
  // The 4000000000 iterations loop will be split into 500000000 x 8 and the
  // outer loop will be parallel. If LLVM is not present, it will not be split,
  // and to cover both of these cases we're looking for 00000000ll; in the
  // output.
  const std::string verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

// Pointwise a * (a * b) on 5x3x40005 inputs, where b is a stride-2 view in
// every dimension of a 10x6x80010 tensor. The result is compared against ATen
// over the whole output.
TEST_F(Kernel, ParallelStrided) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu),
            %1 : Float(5, 3, 40005, strides=[960120, 160020, 2], device=cpu)):
        %2 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, 40005, strides=[120015, 40005, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3, 40005}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 6, 80010}, TensorOptions(kCPU).dtype(at::kFloat))
               .index(
                   {Slice(None, None, 2),
                    Slice(None, None, 2),
                    Slice(None, None, 2)});
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Compare every element, not just the first 5 * 3: an error confined to
  // later parts of the 600,075-element output would otherwise go unnoticed.
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_TRUE(at::allclose(o, ref));
}

// Tests TensorExpr shape inference: intermediate values are typed only as
// `Tensor`, so the kernel should need shapes for the inputs only.
TEST_F(Kernel, DISABLED_Shape_Inference) {
  // disabled: doesn't do stride propagation, and isn't being used currently

  // Checks that `out` has the same rank and sizes as `expected`, then compares
  // their first numel() float elements through the raw data pointers.
  auto expectSameShapeAndValues = [](const at::Tensor& out,
                                     const at::Tensor& expected) {
    TORCH_CHECK_EQ(out.sizes().size(), expected.sizes().size());
    size_t num_el = 1;
    for (const auto idx : c10::irange(expected.sizes().size())) {
      TORCH_CHECK_EQ(out.sizes()[idx], expected.sizes()[idx]);
      num_el *= expected.sizes()[idx];
    }
    const float* out_data = out.data_ptr<float>();
    const float* expected_data = expected.data_ptr<float>();
    for (const auto i : c10::irange(num_el)) {
      TORCH_CHECK_EQ(out_data[i], expected_data[i]);
    }
  };

  {
    const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[12, 2], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor = aten::mul(%0, %2)
        return (%3))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({10, 6}, TensorOptions(kCPU).dtype(at::kFloat))
                 .index({Slice(None, None, 2), Slice(None, None, 2)});
    auto ref = a * (a * b);
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NOT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    expectSameShapeAndValues(stack[0].toTensor(), ref);
  }
  {
    const auto graph_string = R"IR(
      graph(%0 : Float(8, 8, strides=[8, 1], device=cpu),
            %1 : Float(8, 8, strides=[8, 1], device=cpu)):
        %2 : Tensor = aten::mul(%0, %1)
        %3 : Tensor, %4 : Tensor = prim::ConstantChunk[dim=1,chunks=2](%2)
        %r : Tensor = aten::mul(%3, %4)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({8, 8}, TensorOptions(kCPU).dtype(at::kFloat));
    auto t = torch::chunk(a * b, 2, 1);
    auto ref = t[0] * t[1];
    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string verification_pattern =
        R"IR(
# CHECK: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    // ref is 8x4, so this also checks the output shape.
    expectSameShapeAndValues(stack[0].toTensor(), ref);
  }
  {
    // Test that shape inference handles aten::unsqueeze

    const auto graph_string = R"IR(
      graph(%a : Float(4, 2, strides=[2, 1], device=cpu),
            %b : Float(4, 3, 2, strides=[6, 2, 1], device=cpu),
            %c : Float(3, 2, 2, strides=[4, 2, 1], device=cpu)):
        %one : int = prim::Constant[value=1]()
        %minus_one : int = prim::Constant[value=-1]()
        %three : int = prim::Constant[value=3]()
        %minus_four : int = prim::Constant[value=-4]()
        %a1 : Tensor = aten::unsqueeze(%a, %one)        # new size: [4,1,2]
        %a2 : Tensor = aten::unsqueeze(%a1, %minus_one) # new size: [4,1,2,1]
        %b1 : Tensor = aten::unsqueeze(%b, %three)      # new size: [4,3,2,1]
        %c1 : Tensor = aten::unsqueeze(%c, %minus_four) # new size: [1,3,2,2]
        %ab : Tensor = aten::mul(%a2, %b1)         # expected size: [4,3,2,1]
        %abc : Tensor = aten::mul(%ab, %c1)        # expected size: [4,3,2,2]
        return (%abc))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({4, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({4, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({3, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::unsqueeze(at::unsqueeze(a, 1), -1) * at::unsqueeze(b, 3) *
        at::unsqueeze(c, -4);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_mul)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    expectSameShapeAndValues(stack[0].toTensor(), ref);
  }
  {
    // Test that shape inference handles aten::cat

    const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Tensor = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
    auto ref = at::cat({a, b, c}, 1);

    TensorExprKernel k(graph);
    std::vector<at::Tensor> inputs = {a, b, c};
    StmtPtr s = k.getCodeGenStmt();

    std::ostringstream oss;
    oss << *s;

    // Check the IR we produced
    const std::string verification_pattern =
        R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    expectSameShapeAndValues(stack[0].toTensor(), ref);
  }
  {
    // Test that we throw an error when input list for aten::cat is empty

    const auto graph_string = R"IR(
      graph():
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct()
        %r : Tensor = aten::cat(%inputs, %dim)
        return (%r))IR";
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    auto compile = [&]() {
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(), "Empty input list is passed to aten::cat");
  }
  {
    // Test that we throw an error when 'dim' passed to aten::cat is invalid

    const auto ir_dim_99 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=99]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";
    const auto ir_dim_minus_6 = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 3, 2, strides=[6, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=-6]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(5, 3, 2, strides=[6, 2, 1], device=cpu) = aten::cat(%inputs, %dim)
        return (%r))IR";

    auto compile = [](const std::string& graph_string) {
      auto graph = std::make_shared<Graph>();
      parseIR(graph_string, &*graph);
      TensorExprKernel k(graph);
      k.getCodeGenStmt();
    };
    ASSERT_THROWS_WITH(compile(ir_dim_99), "Invalid index");
    ASSERT_THROWS_WITH(compile(ir_dim_minus_6), "Invalid index");
  }
}

// Test that we properly promote input types for aten::cat: concatenating
// Float, Float and Double inputs must produce a Double result equal to
// at::cat of the same tensors.
TEST_F(Kernel, CatInputTypesPromotion) {
  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Double(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Double(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kDouble));
  auto ref = at::cat({a, b, c}, 1);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {a, b, c};
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced
  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_cat)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes and dtype
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  const double* out_data = o.data_ptr<double>();
  const double* ref_data = ref.data_ptr<double>();
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

// Applies sigmoid and a chain of dtype conversions to a BFloat16 input.
// Checks that the printed statement is a two-level loop nest whose body line
// mentions aten_to, and that the Float result matches ATen within 4e-3.
// Requires LLVM codegen; otherwise the test is reported as skipped.
TEST_F(Kernel, ToDType) {
#ifndef TORCH_ENABLE_LLVM
  GTEST_SKIP() << "requires LLVM codegen (TORCH_ENABLE_LLVM)";
#else
  const auto graph_string = R"IR(
      graph(%x.1 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
        %1 : NoneType = prim::Constant()
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=6]()
        %4 : int = prim::Constant[value=15]()
        %5 : int = prim::Constant[value=5]()
        %6 : bool = prim::Constant[value=1]()
        %y.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::sigmoid(%x.1)
        %z.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_reduced_precision(%y.3, %6, %6, %5, %4)
        %h.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::_autocast_to_full_precision(%z.3, %6, %6)
        %i.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%h.3, %3, %2, %2, %1)
        %j.3 : BFloat16(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%i.3, %4, %2, %2, %1)
        %k.3 : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::to(%j.3, %3, %2, %2, %1)
        return (%k.3))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_to
# CHECK-NEXT: }
# CHECK-NEXT: })IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kBFloat16));
  auto ref =
      at::_to_copy(at::sigmoid(a), TensorOptions(kCPU).dtype(at::kFloat));

  std::vector<at::Tensor> inputs = {a};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
#endif
}

// Concatenates two 1x512 inputs along dim 1, passes the result through a
// single-input aten::cat and aten::_cast_Float, and checks both graph outputs
// (%8, the cast, and %7, the concatenation) against ATen.
TEST_F(Kernel, CatAndInlineWithAConstantDim) {
  const auto graph_string = R"IR(
      graph(%0 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu),
            %1 : Float(1, 512, strides=[1024, 1], requires_grad=0, device=cpu)):
        %2 : bool = prim::Constant[value=0]()
        %3 : int = prim::Constant[value=1]()
        %4 : Tensor[] = prim::ListConstruct(%0, %1)
        %5 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%4, %3)
        %6 : Tensor[] = prim::ListConstruct(%5)
        %7 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::cat(%6, %3)
        %8 : Float(1, 1024, strides=[1024, 1], requires_grad=0, device=cpu) = aten::_cast_Float(%7, %2)
        return (%8, %7))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  TensorExprKernel k(graph);

  auto a = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({1, 512}, TensorOptions(kCPU).dtype(at::kFloat));
  auto cat_ref = at::cat({a, b}, 1);
  auto ref = at::_cast_Float(cat_ref, 0);

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  // The graph returns two values; both are checked below.
  ASSERT_EQ(stack.size(), 2u);

  auto o = stack[0].toTensor();
  ASSERT_EQ(o.sizes(), ref.sizes());
  ASSERT_EQ(o.dtype(), ref.dtype());
  ASSERT_TRUE(at::allclose(o, ref));

  auto cat_out = stack[1].toTensor();
  ASSERT_EQ(cat_out.sizes(), cat_ref.sizes());
  ASSERT_EQ(cat_out.dtype(), cat_ref.dtype());
  ASSERT_TRUE(at::allclose(cat_out, cat_ref));
}

// Concatenates tanh of a 0x64 and a 10x64 tensor along dim 0 and checks the
// result against ATen under both settings of getCatWoConditionals().
TEST_F(Kernel, CatWithEmptyInputs) {
  // Restores the flag in its destructor. A failing ASSERT_* returns from the
  // test body early, which previously skipped the restore and leaked the
  // modified value into later tests.
  struct RestoreCatWoConditionals {
    bool saved = getCatWoConditionals();
    ~RestoreCatWoConditionals() {
      getCatWoConditionals() = saved;
    }
  } restore_cat_wo_conditionals;

  const auto graph_string = R"IR(
        graph(%0 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu),
              %1 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu)):
          %3 : int = prim::Constant[value=0]()
          %6 : Float(0, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%0)
          %7 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::tanh(%1)
          %10 : Tensor[] = prim::ListConstruct(%6, %7)
          %11 : Float(10, 64, strides=[64, 1], requires_grad=0, device=cpu) = aten::cat(%10, %3)
          return (%11))IR";

  // Inputs and reference do not depend on the flag, so build them once.
  auto a = at::rand({0, 64}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({10, 64}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::cat({at::tanh(a), at::tanh(b)}, 0);

  for (auto cat_wo_conditionals : {true, false}) {
    getCatWoConditionals() = cat_wo_conditionals;

    // The kernel is rebuilt per iteration so it is compiled under the
    // current flag value.
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);
    TensorExprKernel k(graph);

    std::vector<at::Tensor> inputs = {a, b};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
}

// With getCatWoConditionals() enabled, checks that aten::cat of three inputs is
// printed as three separate loop nests, each storing to aten_cat. Also checks
// that the result equals at::cat.
TEST_F(Kernel, CatWoConditionals) {
  // Restores the flag on every exit path, including early returns from a
  // failing ASSERT_* and exceptions from TORCH_CHECK_EQ.
  struct RestoreCatWoConditionals {
    bool saved = getCatWoConditionals();
    ~RestoreCatWoConditionals() {
      getCatWoConditionals() = saved;
    }
  } restore_cat_wo_conditionals;
  getCatWoConditionals() = true;

  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, 2, strides=[6, 2, 1], device=cpu),
            %b : Float(5, 7, 2, strides=[14, 2, 1], device=cpu),
            %c : Float(5, 9, 2, strides=[18, 2, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, 2, strides=[38, 2, 1]) = aten::cat(%inputs, %dim)
        return (%r))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat
# CHECK: for
# CHECK: for
# CHECK: aten_cat)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 7, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto c = at::rand({5, 9, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::cat({a, b, c}, 1);

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes and dtype
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  const float* out_data = o.data_ptr<float>();
  const float* ref_data = ref.data_ptr<float>();
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

// Sets getCatWoConditionals() to false and getOptConditionals() to true, then
// builds cat followed by relu. Checks that the printed statement has three
// loop nests with aten_relu stores and no "Allocate"/"Free" after them. Also
// checks the result against ATen.
TEST_F(Kernel, OptimizeConditionals) {
  // Restores both flags on every exit path, including early returns from a
  // failing ASSERT_* and exceptions from TORCH_CHECK_EQ.
  struct RestoreConditionalFlags {
    bool saved_cat_wo_conditionals = getCatWoConditionals();
    bool saved_opt_conditionals = getOptConditionals();
    ~RestoreConditionalFlags() {
      getOptConditionals() = saved_opt_conditionals;
      getCatWoConditionals() = saved_cat_wo_conditionals;
    }
  } restore_conditional_flags;
  getCatWoConditionals() = false;
  getOptConditionals() = true;

  const auto graph_string = R"IR(
      graph(%a : Float(5, 3, strides=[3, 1], device=cpu),
            %b : Float(5, 7, strides=[7, 1], device=cpu),
            %c : Float(5, 9, strides=[9, 1], device=cpu)):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(5, 19, strides=[19, 1]) = aten::cat(%inputs, %dim)
        %t : Float(5, 19, strides=[19, 1]) = aten::relu(%r)
        return (%t))IR";

  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  const std::string verification_pattern =
      R"IR(
# CHECK: for
# CHECK-NEXT: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK: for
# CHECK-NEXT: aten_relu
# CHECK-NOT: Allocate
# CHECK-NOT: Free)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto b = at::rand({5, 7}, TensorOptions(kCPU).dtype(at::kFloat));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  auto c = at::rand({5, 9}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = at::relu(at::cat({a, b, c}, 1));

  std::vector<at::Tensor> inputs = {a, b, c};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();

  // Check sizes and dtype
  TORCH_CHECK_EQ(o.sizes().size(), ref.sizes().size());
  TORCH_CHECK_EQ(o.dtype(), ref.dtype());
  size_t num_el = 1;
  for (const auto idx : c10::irange(ref.sizes().size())) {
    TORCH_CHECK_EQ(o.sizes()[idx], ref.sizes()[idx]);
    num_el *= ref.sizes()[idx];
  }

  // Check the contents
  const float* out_data = o.data_ptr<float>();
  const float* ref_data = ref.data_ptr<float>();
  for (const auto i : c10::irange(num_el)) {
    TORCH_CHECK_EQ(out_data[i], ref_data[i]);
  }
}

namespace {
842

843
std::string dtypeConstant(ScalarType scalar_type) {
844
  if (scalar_type == ScalarType::Undefined) {
845
    return "None = prim::Constant()";
846
  } else {
847
    at::jit::TemplateEnv env_dtype;
848
    env_dtype.d("scalar_type", static_cast<int>(scalar_type));
849
    return format("int = prim::Constant[value=${scalar_type}]()", env_dtype);
850
  }
851
}
852

853
at::Tensor iotaTensor(IntArrayRef sizes, const at::TensorOptions& options) {
854
  int64_t numel = std::accumulate(
855
      sizes.begin(),
856
      sizes.end(),
857
      1,
858
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
859
      std::multiplies<int64_t>());
860
  std::vector<float> values(numel);
861
  std::iota(values.begin(), values.end(), 0);
862
  auto a = at::tensor(values, options);
863
  return a.reshape(sizes);
864
}
865

866
} // namespace
867

868
TEST_F(Kernel, SumAllAxes) {
  // Lowering of aten::sum over every axis, once with the default (None)
  // dtype and once with an explicit Double output dtype.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : ${dtype}
        %2 : ${out_dtype}(requires_grad=0, device=cpu) = aten::sum(%0, %1)
        return (%2))IR";
  const auto input = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  for (const auto scalar_type : {ScalarType::Undefined, ScalarType::Double}) {
    const bool has_dtype = scalar_type != ScalarType::Undefined;

    at::jit::TemplateEnv env;
    env.s("dtype", dtypeConstant(scalar_type));
    env.s("out_dtype", has_dtype ? "Double" : "Float");

    auto graph = std::make_shared<Graph>();
    parseIR(format(graph_template, env), &*graph);

    // Eager reference; `dtype` stays empty for the Undefined case.
    const std::optional<c10::ScalarType> dtype = has_dtype
        ? std::optional<c10::ScalarType>(scalar_type)
        : std::nullopt;
    const auto ref = input.sum(/*dtype=*/dtype);

    TensorExprKernel k(graph);
    std::ostringstream oss;
    oss << *k.getCodeGenStmt();

    // The generated IR should contain two consecutive `for` lines.
    const std::string verification_pattern = R"IR(
# CHECK: for
# CHECK-NEXT: for)IR";
    torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{input});
    k.run(stack);
    const auto o = stack[0].toTensor();
    ASSERT_EQ(o.sizes(), ref.sizes());
    ASSERT_EQ(o.dtype(), ref.dtype());
    ASSERT_TRUE(at::allclose(o, ref));
  }
}
918

919
// Renders `li` as a comma-separated list ("a, b, c"), suitable for splicing
// into the size/stride annotations of an IR template. An empty list yields
// an empty string.
std::string li_to_str(at::ArrayRef<int64_t> li) {
  std::ostringstream joined;
  const char* separator = "";
  for (const auto value : li) {
    joined << separator << value;
    separator = ", ";
  }
  return joined.str();
}
931

932
TEST_F(Kernel, SumOneAxis) {
  // Lowering of aten::sum over a single axis: every axis (negative indices
  // included), with and without keepdim, with default or Double dtype.
  const auto graph_template = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
        %1 : int[] = prim::Constant[value=[${dim}]]()
        %2 : bool = prim::Constant[value=${keepdim}]()
        %3 : ${dtype}
        %4 : ${out_dtype}(${size}, strides=[${strides}], device=cpu) = aten::sum(%0, %1, %2, %3)
        return (%4))IR";
  const auto input = iotaTensor({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  const auto rank = input.dim();

  for (int dim = -rank; dim < rank; ++dim) {
    for (const bool keepdim : {false, true}) {
      for (const auto scalar_type :
           {ScalarType::Undefined, ScalarType::Double}) {
        const bool has_dtype = scalar_type != ScalarType::Undefined;
        const std::optional<c10::ScalarType> dtype = has_dtype
            ? std::optional<c10::ScalarType>(scalar_type)
            : std::nullopt;
        const auto ref = input.sum({dim}, /*keepdim=*/keepdim, /*dtype=*/dtype);

        at::jit::TemplateEnv env;
        env.d("dim", dim);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(scalar_type));
        env.s("out_dtype", has_dtype ? "Double" : "Float");
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));

        auto graph = std::make_shared<Graph>();
        parseIR(format(graph_template, env), &*graph);

        TensorExprKernel k(graph);
        std::ostringstream oss;
        oss << *k.getCodeGenStmt();

        // Check the loop structure emitted around the `sum` buffer.
        const std::string verification_pattern = R"IR(
# CHECK: for (int64_t
# CHECK-NEXT: sum
# CHECK-NEXT: for (int64_t
# CHECK-NEXT:   sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack =
            fmap<IValue>(std::vector<at::Tensor>{input});
        k.run(stack);
        const auto o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref, 4E-3, 4E-3));
      }
    }
  }
}
993

994
TEST_F(Kernel, SumMultipleAxes) {
  // Lowering of aten::sum over a pair of axes passed as an int list.
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], requires_grad=0, device=cpu)):
        %1 : int = prim::Constant[value=${dim1}]()
        %2 : int = prim::Constant[value=${dim2}]()
        %3 : int[] = prim::ListConstruct(%1, %2)
        %4 : bool = prim::Constant[value=${keepdim}]()
        %5 : ${dtype}
        %6 : Float(${size}, strides=[${strides}], requires_grad=0, device=cpu) = aten::sum(%0, %3, %4, %5)
        return (%6))IR";
  const auto input =
      iotaTensor({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  const int64_t rank = input.dim();

  // Only non-negative axis pairs with dim1 < dim2 are visited: the number of
  // pairs grows quadratically, so this keeps the running time reasonable.
  for (int64_t dim1 = 0; dim1 < rank; ++dim1) {
    for (int64_t dim2 = dim1 + 1; dim2 < rank; ++dim2) {
      for (const bool keepdim : {false, true}) {
        const auto ref =
            input.sum(IntArrayRef{dim1, dim2}, /*keepdim=*/keepdim);

        at::jit::TemplateEnv env;
        env.d("dim1", dim1);
        env.d("dim2", dim2);
        env.d("keepdim", keepdim);
        env.s("dtype", dtypeConstant(ScalarType::Undefined));
        env.s("size", li_to_str(ref.sizes()));
        env.s("strides", li_to_str(ref.strides()));

        auto graph = std::make_shared<Graph>();
        parseIR(format(graph_template, env), &*graph);

        TensorExprKernel k(graph);
        std::ostringstream oss;
        oss << *k.getCodeGenStmt();

        // Four loops (one per input axis) followed by the `sum` buffer.
        const std::string verification_pattern = R"IR(
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: for (int64_t
# CHECK: sum)IR";
        torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

        std::vector<IValue> stack =
            fmap<IValue>(std::vector<at::Tensor>{input});
        k.run(stack);
        const auto o = stack[0].toTensor();
        ASSERT_EQ(o.sizes(), ref.sizes());
        ASSERT_EQ(o.dtype(), ref.dtype());
        ASSERT_TRUE(at::allclose(o, ref));
      }
    }
  }
}
1055

1056
// This test and the following ones testing Softmax only tests with dim set
1057
// to one of the valid input dimensions. It does not test with dim=None
1058
// because that is supposed to be deprecated.
1059
TEST_F(Kernel, Softmax2D) {
1060
  const auto graph_template = R"IR(
1061
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu)):
1062
        %1 : int = prim::Constant[value=${dim}]()
1063
        %dt_float : int = prim::Constant[value=7]()
1064
        %dt_none : NoneType = prim::Constant()
1065
        %4 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %${dt})
1066
        return (%4))IR";
1067

1068
  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1069

1070
  const std::string& verification_template =
1071
      R"IR(
1072
        # CHECK: for (int i${other_dim} = 0; i${other_dim} < ${other_dim_size}
1073
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
1074
        # CHECK-NEXT: aten_softmax_max
1075
        # CHECK: for (int i${other_dim}_1 = 0; i${other_dim}_1 < ${other_dim_size}
1076
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
1077
        # CHECK-NEXT: aten_softmax_sum
1078
        # CHECK: for (int i0_2 = 0; i0_2 < 5
1079
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
1080
        # CHECK-NEXT: aten_softmax)IR";
1081

1082
  for (bool empty_dtype : {false, true}) {
1083
    for (auto log_softmax : {false, true}) {
1084
      for (const auto softmax_dim : c10::irange(a.dim())) {
1085
        auto softmax_dim_size = a.sizes()[softmax_dim];
1086
        auto other_dim = (softmax_dim + 1) % a.dim();
1087
        auto ref =
1088
            log_softmax ? a.log_softmax(softmax_dim) : a.softmax(softmax_dim);
1089
        at::jit::TemplateEnv env;
1090
        env.d("dim", softmax_dim);
1091
        env.s("op", log_softmax ? "log_softmax" : "softmax");
1092
        env.s("size", li_to_str(ref.sizes()));
1093
        env.s("strides", li_to_str(ref.strides()));
1094
        env.s("dt", empty_dtype ? "dt_none" : "dt_float");
1095

1096
        const auto graph_string = format(graph_template, env);
1097

1098
        auto graph = std::make_shared<Graph>();
1099
        parseIR(graph_string, &*graph);
1100

1101
        TensorExprKernel k(graph);
1102
        std::vector<at::Tensor> inputs = {a};
1103
        StmtPtr s = k.getCodeGenStmt();
1104

1105
        std::ostringstream oss;
1106
        oss << *s;
1107

1108
        at::jit::TemplateEnv ver_env;
1109
        ver_env.d("other_dim", other_dim);
1110
        ver_env.d("other_dim_size", a.sizes()[other_dim]);
1111
        ver_env.d("softmax_dim", softmax_dim);
1112
        ver_env.d("softmax_dim_size", softmax_dim_size);
1113
        const auto verification_pattern =
1114
            format(verification_template, ver_env);
1115

1116
        // verification sting temporarily disabled until
1117
        // inlining of exp() is benchmarked and determined
1118
        // torch::jit::testing::FileCheck().run(verification_pattern,
1119
        // oss.str());
1120

1121
        std::vector<IValue> stack = fmap<IValue>(inputs);
1122
        k.run(stack);
1123
        auto output = stack[0].toTensor();
1124
        ASSERT_EQ(output.sizes(), ref.sizes());
1125
        ASSERT_TRUE(at::allclose(output, ref));
1126
      }
1127
    }
1128
  }
1129
}
1130

1131
TEST_F(Kernel, Softmax3D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(3, 4, 5, strides=[20, 5, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  const auto input =
      at::rand({3, 4, 5}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 3
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 4
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 5
        # CHECK-NEXT: aten_softmax)IR";

  for (const bool log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(input.dim())) {
      // Every axis except the softmax axis, in increasing order.
      std::vector<int> other_dims;
      for (const auto d : c10::irange(input.dim())) {
        if (d == softmax_dim) {
          continue;
        }
        other_dims.push_back(d);
      }
      const auto ref = log_softmax ? input.log_softmax(softmax_dim)
                                   : input.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      auto graph = std::make_shared<Graph>();
      parseIR(format(graph_template, env), &*graph);

      TensorExprKernel k(graph);
      std::ostringstream oss;
      oss << *k.getCodeGenStmt();

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", input.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", input.sizes()[other_dims[1]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", input.sizes()[softmax_dim]);
      [[maybe_unused]] const auto verification_pattern =
          format(verification_template, ver_env);

      // The verification string is temporarily disabled until inlining of
      // exp() is benchmarked and determined.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{input});
      k.run(stack);
      const auto output = stack[0].toTensor();

      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}
1208

1209
TEST_F(Kernel, Softmax4D) {
  const auto graph_template = R"IR(
      graph(%0 : Float(2, 3, 2, 3, strides=[18, 6, 3, 1], device=cpu)):
        %1 : int = prim::Constant[value=${dim}]()
        %2 : int = prim::Constant[value=7]()
        %3 : Float(${size}, strides=[${strides}]) = aten::${op}(%0, %1, %2)
        return (%3))IR";

  const auto input =
      at::rand({2, 3, 2, 3}, TensorOptions(kCPU).dtype(at::kFloat));

  const std::string verification_template =
      R"IR(
        # CHECK: for (int i${dim1} = 0; i${dim1} < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2} = 0; i${dim2} < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3} = 0; i${dim3} < ${dim3_size}
        # CHECK: for (int i${softmax_dim} = 0; i${softmax_dim} < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_max
        # CHECK: for (int i${dim1}_1 = 0; i${dim1}_1 < ${dim1_size}
        # CHECK-NEXT: for (int i${dim2}_1 = 0; i${dim2}_1 < ${dim2_size}
        # CHECK-NEXT: for (int i${dim3}_1 = 0; i${dim3}_1 < ${dim3_size}
        # CHECK: for (int i${softmax_dim}_1 = 0; i${softmax_dim}_1 < ${softmax_dim_size}
        # CHECK-NEXT: aten_softmax_sum
        # CHECK: for (int i0_2 = 0; i0_2 < 2
        # CHECK-NEXT: for (int i1_2 = 0; i1_2 < 3
        # CHECK-NEXT: for (int i2_2 = 0; i2_2 < 2
        # CHECK-NEXT: for (int i3_2 = 0; i3_2 < 3
        # CHECK-NEXT: aten_softmax)IR";

  for (const bool log_softmax : {false, true}) {
    for (const auto softmax_dim : c10::irange(input.dim())) {
      // Every axis except the softmax axis, in increasing order.
      std::vector<int> other_dims;
      for (const auto d : c10::irange(input.dim())) {
        if (d == softmax_dim) {
          continue;
        }
        other_dims.push_back(d);
      }
      const auto ref = log_softmax ? input.log_softmax(softmax_dim)
                                   : input.softmax(softmax_dim);

      at::jit::TemplateEnv env;
      env.d("dim", softmax_dim);
      env.s("op", log_softmax ? "log_softmax" : "softmax");
      env.s("size", li_to_str(ref.sizes()));
      env.s("strides", li_to_str(ref.strides()));

      auto graph = std::make_shared<Graph>();
      parseIR(format(graph_template, env), &*graph);

      TensorExprKernel k(graph);
      std::ostringstream oss;
      oss << *k.getCodeGenStmt();

      at::jit::TemplateEnv ver_env;
      ver_env.d("dim1", other_dims[0]);
      ver_env.d("dim1_size", input.sizes()[other_dims[0]]);
      ver_env.d("dim2", other_dims[1]);
      ver_env.d("dim2_size", input.sizes()[other_dims[1]]);
      ver_env.d("dim3", other_dims[2]);
      ver_env.d("dim3_size", input.sizes()[other_dims[2]]);
      ver_env.d("softmax_dim", softmax_dim);
      ver_env.d("softmax_dim_size", input.sizes()[softmax_dim]);
      [[maybe_unused]] const auto verification_pattern =
          format(verification_template, ver_env);

      // The verification string is temporarily disabled until inlining of
      // exp() is benchmarked and determined.
      // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

      std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{input});
      k.run(stack);
      const auto output = stack[0].toTensor();
      ASSERT_EQ(output.sizes(), ref.sizes());
      ASSERT_TRUE(at::allclose(output, ref));
    }
  }
}
1290

1291
TEST_F(Kernel, SignTest) {
  const auto graph_template = R"IR(
      graph(%0 : ${dtype}(${size}, strides=[1], device=cpu)):
        %2 : ${dtype}(${size}, strides=[1]) = aten::sign(%0)
        return (%2))IR";

  // Compiles `graph_string`, runs it on `input` and compares the result with
  // eager at::sign.
  auto run_test = [](const std::string& graph_string, const at::Tensor& input) {
    auto graph = std::make_shared<Graph>();
    parseIR(graph_string, &*graph);

    TensorExprKernel k(graph);
    StmtPtr s = k.getCodeGenStmt();

    std::vector<at::Tensor> inputs = {input};
    std::vector<IValue> stack = fmap<IValue>(inputs);
    k.run(stack);
    auto o = stack[0].toTensor();
    auto ref = at::sign(input);
    ASSERT_TRUE(at::allclose(o, ref));
  };
  const auto common_options = at::TensorOptions()
                                  .layout(at::kStrided)
                                  .device(at::kCPU)
                                  .requires_grad(false);
  const int64_t default_input_size = 100;

  // Builds a 1-D input of `default_input_size` random values followed by
  // `corner_values`, instantiates the IR template for `dtype_name` and the
  // total length, and runs it. Shared by the Float and Double cases, which
  // previously were two copy-pasted switch branches.
  auto run_with_corner_cases = [&](const char* dtype_name,
                                   at::ScalarType dtype,
                                   auto corner_values) {
    const auto options = common_options.dtype(dtype);
    // from_blob does not take ownership of `corner_values`; at::cat below
    // copies the data into a fresh tensor before this lambda returns.
    // Tensor sizes are int64_t; the previous `long` cast is only 32 bits on
    // LLP64 platforms.
    const auto corner_case_inputs = at::from_blob(
        corner_values.data(),
        {static_cast<int64_t>(corner_values.size())},
        options);
    const auto rand_input = at::rand({default_input_size}, options);
    const auto input = at::cat({rand_input, corner_case_inputs});

    at::jit::TemplateEnv env;
    env.s("dtype", dtype_name);
    env.d("size", at::numel(input));
    run_test(format(graph_template, env), input);
  };

  // Corner cases: signed zeros, signed infinities and signed NaNs.
  run_with_corner_cases(
      "Float",
      at::kFloat,
      std::vector<float>{
          0.0f,
          -0.0f,
          std::numeric_limits<float>::infinity(),
          -std::numeric_limits<float>::infinity(),
          std::nanf("1"),
          -std::nanf("1")});
  run_with_corner_cases(
      "Double",
      at::kDouble,
      std::vector<double>{
          0.0,
          -0.0,
          std::numeric_limits<double>::infinity(),
          -std::numeric_limits<double>::infinity(),
          std::nan("1"),
          -std::nan("1")});
}
1368

1369
TEST_F(Kernel, InlineProducerIntoReduction) {
  // The producer (aten::mul %2) should be inlined into the reduction
  // (aten::sum %4), leaving a single loop nest.
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%0, %1)
        %3 : int = prim::Constant[value=7]()
        %4 : Double(device=cpu) = aten::sum(%2, %3)
        return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel kernel(graph);
  std::ostringstream ir;
  ir << *kernel.getCodeGenStmt();

  // One 5x3 loop nest accumulating into `sum`, and no further loops.
  const std::string verification_pattern =
      R"IR(
        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
        # CHECK-NEXT:   sum
        # CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, ir.str());

  const auto options = TensorOptions(kCPU).dtype(at::kFloat);
  auto a = at::rand({5, 3}, options);
  auto b = at::rand({5, 3}, options);
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{a, b});
  kernel.run(stack);
  const auto o = stack[0].toTensor();
  const auto ref = (a * b).sum(at::kDouble);
  ASSERT_TRUE(at::allclose(o, ref));
}
1405

1406
TEST_F(Kernel, InlineReductionIntoConsumer) {
  // Inline the producer (mul %2) into the reduction (sum %4), but DO NOT
  // inline the reduction into its consumer (mul %5). %2 feeds both %4 and %5.
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[3, 1], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : int = prim::Constant[value=6]()
        %4 : Float(device=cpu) = aten::sum(%2, %3)
        %5 : Float(5, 3, strides=[3, 1], device=cpu) = aten::mul(%2, %4)
        return (%5))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced: exactly two 5x3 loop nests, the first
  // computing `sum` and the second computing `aten_mul`.
  const std::string& verification_pattern =
      R"IR(
        # CHECK: for (int64_t i_1 = 0ll; i_1 < 5
        # CHECK-NEXT: for (int64_t j_1 = 0ll; j_1 < 3
        # CHECK-NEXT:   sum
        # CHECK: for (int64_t i_2 = 0ll; i_2 < 5
        # CHECK-NEXT: for (int64_t j_2 = 0ll; j_2 < 3
        # CHECK-NEXT:   aten_mul
        # CHECK-NOT: for)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  auto ref = (a * b).sum(at::kFloat) * (a * b);
  ASSERT_TRUE(at::allclose(o, ref));
}
1447

1448
TEST_F(Kernel, SanitizeNames_CUDA) {
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cuda:0),
            %1 : Float(5, 3, strides=[3, 1], device=cuda:0)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %4 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%4))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  // Give the inputs debug names containing characters such as ':'; the
  // kernel must still build and compute the right result with them.
  graph->inputs().at(0)->setDebugName("aten::add:");
  graph->inputs().at(1)->setDebugName("aten::add_");
  TensorExprKernel k(graph);

  const auto options = TensorOptions(kCUDA).dtype(at::kFloat);
  auto a = at::rand({5, 3}, options);
  auto b = at::rand({5, 3}, options);
  const auto ref = a * (a * b);

  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{a, b});
  k.run(stack);
  const auto o = stack[0].toTensor();
  ASSERT_TRUE(at::allclose(o, ref));
}
1469

1470
TEST_F(Kernel, SanitizeConstants_CUDA) {
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cuda:0)):
          %none : NoneType = prim::Constant()
          %size : int = prim::Constant[value=16]()
          %sizes : int[] = prim::ListConstruct(%size, %size)
          %30 : Device = prim::Constant[value="cuda"]()
          %y : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::ones(%sizes, %none, %none, %30, %none)
          %z : Float(16, 16, strides=[16, 1], device=cuda:0) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser cannot express tensor constants, so the graph builds one
  // with aten::ones and constant propagation folds it into a constant node.
  ConstantPropagation(graph);

  // Rename that constant to something containing '.', which is not allowed
  // in generated names; TensorExprKernel's sanitizer has to fix it up.
  auto const_node = graph->nodes().front();
  const_node->output()->setDebugName("illegal.name");

  // Make sure the illegal name really is on a constant node in the graph.
  ASSERT_EQ(const_node->kind(), prim::Constant);
  ASSERT_NE(const_node->output()->debugName().find('.'), std::string::npos);

  TensorExprKernel k(graph);

  const auto options = TensorOptions(kCUDA).dtype(at::kFloat);
  auto x = at::rand({16, 16}, options);
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x});
  k.run(stack);
  const auto o = stack[0].toTensor();
  const auto ref = x * at::ones({16, 16}, options);
  ASSERT_TRUE(at::allclose(o, ref));
}
1506

1507
TEST_F(Kernel, ConstantTensors) {
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %size : int = prim::Constant[value=16]()
          %sizes : int[] = prim::ListConstruct(%size, %size)
          %y : Float(16, 16, strides=[16, 1], device=cpu) = aten::ones(%sizes, %none, %none, %none, %none)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser cannot express tensor constants, so the graph builds one
  // with aten::ones and constant propagation folds it.
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  const auto options = TensorOptions(kCPU).dtype(at::kFloat);
  auto x = at::rand({16, 16}, options);
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x});
  k.run(stack);
  const auto o = stack[0].toTensor();
  const auto ref = x * at::ones({16, 16}, options);
  ASSERT_TRUE(at::allclose(o, ref));
}
1533

1534
TEST_F(Kernel, ConstantTensorsNonContiguous) {
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %dtype : int = prim::Constant[value=6]()
          %c0 : int = prim::Constant[value=0]()
          %c256 : int = prim::Constant[value=256]()
          %c16 : int = prim::Constant[value=16]()
          %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
          %sizes : int[] = prim::ListConstruct(%c16, %c16)
          %y_t : Tensor = aten::view(%y_flat, %sizes)
          %y : Tensor = aten::t(%y_t)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // The IR parser cannot express tensor constants, so the graph produces a
  // transposed (non-contiguous) tensor through arange/view/t, and constant
  // propagation folds that chain into a constant.
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  const auto options = TensorOptions(kCPU).dtype(at::kFloat);
  auto x = at::rand({16, 16}, options);
  std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>{x});
  k.run(stack);
  const auto o = stack[0].toTensor();

  const auto y = at::arange(0, 256, options).view({16, 16}).t();
  const auto ref = x * y;
  ASSERT_TRUE(at::allclose(o, ref));
}
1567

1568
TEST_F(Kernel, RunFast) {
#ifdef TORCH_ENABLE_LLVM
  // TODO: Implement call_raw in IREval and remove the ifdef

  // %1 is declared with strides [1, 5]. runFast receives bare data pointers,
  // so the tensor passed for it must really have that layout.
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // Transposing a contiguous 3x5 tensor gives a 5x3 view with strides [1, 5].
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);

  k.runFast({a.data_ptr(), b.data_ptr()}, {o.data_ptr()});
  // Exact element-wise comparison. Unlike indexing ref.data_ptr() as a flat
  // float array, at::equal does not assume `ref` is contiguous, and a
  // mismatch is reported as a gtest failure.
  ASSERT_TRUE(at::equal(o, ref));
#endif
}
1594

1595
TEST_F(Kernel, RunWithAllocatedOutputs) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(5, 3, strides=[3, 1], device=cpu),
            %1 : Float(5, 3, strides=[1, 5], device=cpu)):
        %2 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %1)
        %3 : Float(5, 3, strides=[3, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  // Transposing a contiguous 3x5 tensor gives a 5x3 view with strides [1, 5],
  // matching the declared layout of %1.
  auto b =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  auto o = at::zeros({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);

  // The preallocated output `o` is placed on the stack ahead of the inputs.
  std::vector<at::Tensor> args = {o, a, b};
  std::vector<IValue> stack = fmap<IValue>(args);
  k.runWithAllocatedOutputs(stack);
  // Exact element-wise comparison. Unlike indexing ref.data_ptr() as a flat
  // float array, at::equal does not assume `ref` is contiguous, and a
  // mismatch is reported as a gtest failure.
  ASSERT_TRUE(at::equal(o, ref));
#endif
}
1621

1622
TEST_F(Kernel, CodegenInspection) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
        graph(%x : Float(16, 16, strides=[16, 1], device=cpu)):
          %none : NoneType = prim::Constant()
          %dtype : int = prim::Constant[value=6]()
          %c0 : int = prim::Constant[value=0]()
          %c256 : int = prim::Constant[value=256]()
          %c16 : int = prim::Constant[value=16]()
          %y_flat : Tensor = aten::arange(%c0, %c256, %dtype, %none, %none, %none)
          %sizes : int[] = prim::ListConstruct(%c16, %c16)
          %y_t : Tensor = aten::view(%y_flat, %sizes)
          %y : Tensor = aten::t(%y_t)
          %z : Float(16, 16, strides=[16, 1], device=cpu) = aten::mul(%x, %y)
          return (%z))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);
  // IRParser doesn't support tensor constants, so we generate several aten
  // calls to produce non-contiguous constant tensor and then const-prop it
  ConstantPropagation(graph);

  TensorExprKernel k(graph);

  // Check that we could retrieve generated assembly.
  // NOTE(review): `retq` is x86-64 AT&T syntax; this pattern presumably fails
  // when LLVM targets another architecture -- confirm before enabling there.
  auto asm_str = k.getCodeText("asm");
  const std::string& asm_verification_pattern =
      R"ASM(
        # CHECK: .text
        # CHECK: retq)ASM";
  torch::jit::testing::FileCheck().run(asm_verification_pattern, asm_str);

  // Check that we could retrieve info about codegen parameters
  auto constants = k.getConstantDescriptors();
  auto buf_args = k.getBufferArgs();
  // Expected buf args: [input0, output0, constant0]
  // Unsigned literals avoid a signed/unsigned comparison against size().
  ASSERT_EQ(buf_args.size(), 3u);
  ASSERT_EQ(constants.size(), 1u);
  ASSERT_TRUE(
      !buf_args[0].isVar() && !buf_args[1].isVar() && !buf_args[2].isVar());
#endif
}
1663

1664
// Custom NNC lowering for aten::nan_to_num, used by the CustomLowering test.
// It builds a tensor named "custom_nan_to_num" whose elements are 0.0f where
// the input element is NaN, and the input element otherwise.
//
// Unlike aten::nan_to_num, +/-inf values pass through unchanged. The test only
// needs to observe that this lowering was picked.
//
// inputs:        inputs[0] must hold a BufHandle (std::get throws
//                std::bad_variant_access otherwise); other entries are not read.
// outputShape:   forwarded to Compute as the result shape.
// outputStrides: forwarded to Compute as the result strides.
// outputType, device: not used by this lowering.
Tensor lowerNanToNum(
    const std::vector<ArgValue>& inputs,
    const std::vector<ExprHandle>& outputShape,
    const std::vector<ExprHandle>& outputStrides,
    const std::optional<ScalarType>& outputType,
    at::Device device) {
  auto input_buf = std::get<BufHandle>(inputs[0]);
  auto e = Compute(
      "custom_nan_to_num",
      outputShape,
      outputStrides,
      [&](const std::vector<VarHandle>& axes) {
        // Load the element at the same index as the output element.
        std::vector<ExprHandle> indices(axes.begin(), axes.end());
        auto load = input_buf.load(indices);
        return IfThenElse::make(Cast::make(kBool, isnan(load)), 0.0f, load);
      });
  return e;
}
1682

1683
TEST_F(Kernel, CustomLowering) {
  const auto graph_string = R"IR(
      graph(%x : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu)):
          %none : NoneType = prim::Constant()
          %y : Float(2, 2, strides=[2, 1], requires_grad=0, device=cpu) = aten::nan_to_num(%x, %none, %none, %none)
          return (%y)
)IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  // Override the default lowering of aten::nan_to_num with lowerNanToNum.
  std::unordered_map<c10::Symbol, NNCLoweringFunction> lowerings = {
      {aten::nan_to_num, lowerNanToNum}};
  TensorExprKernel k(graph, lowerings);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Check that our custom lowering is actually used: the tensor name and the
  // isnan call it emits must both appear (checked independently, no ordering).
  torch::jit::testing::FileCheck().check("custom_nan_to_num")->run(oss.str());
  torch::jit::testing::FileCheck().check("isnan")->run(oss.str());
}
1705

1706
TEST_F(Kernel, Vectorize) {
#ifdef TORCH_ENABLE_LLVM
  const auto graph_string = R"IR(
      graph(%0 : Float(100, 16, strides=[16, 1], device=cpu),
            %1 : Float(100, 16, strides=[16, 1], device=cpu)):
        %2 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %1)
        %3 : Float(100, 16, strides=[16, 1]) = aten::mul(%0, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto a = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto b = at::rand({100, 16}, TensorOptions(kCPU).dtype(at::kFloat));
  auto ref = a * (a * b);
  TensorExprKernel k(graph);
  StmtPtr s = k.getCodeGenStmt();

  std::ostringstream oss;
  oss << *s;

  // Check the IR we produced: a vectorized loop shows up as a Ramp.
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto o = stack[0].toTensor();
  // Exact comparison reported through gtest (a fatal TORCH_CHECK_EQ would
  // abort the whole test binary on mismatch).
  ASSERT_TRUE(at::equal(o, ref));
#endif
}
1740

1741
// TODO: To vectorize loopnest for 100x3 case, we need to flatten loops first.
1742
TEST_F(Kernel, DISABLED_FlattenVectorize) {
1743
#ifdef TORCH_ENABLE_LLVM
1744
  const auto graph_string = R"IR(
1745
      graph(%0 : Float(100, 3, strides=[3, 1], device=cpu),
1746
            %1 : Float(100, 3, strides=[3, 1], device=cpu)):
1747
        %2 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %1)
1748
        %3 : Float(100, 3, strides=[3, 1]) = aten::mul(%0, %2)
1749
        return (%3))IR";
1750
  auto graph = std::make_shared<Graph>();
1751
  parseIR(graph_string, &*graph);
1752

1753
  auto a = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1754
  auto b = at::rand({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1755
  auto o = at::zeros({100, 3}, TensorOptions(kCPU).dtype(at::kFloat));
1756
  auto ref = a * (a * b);
1757
  TensorExprKernel k(graph);
1758
  std::vector<at::Tensor> inputs = {a, b};
1759
  StmtPtr s = k.getCodeGenStmt();
1760

1761
  std::ostringstream oss;
1762
  oss << *s;
1763

1764
  // Check the IR we produced
1765
  const std::string& verification_pattern = R"IR(# CHECK: Ramp)IR";
1766
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1767

1768
  std::vector<IValue> stack = fmap<IValue>(inputs);
1769
  k.run(stack);
1770
  o = stack[0].toTensor();
1771
  for (size_t i = 0; i < 100 * 3; i++) {
1772
    TORCH_CHECK_EQ(((float*)o.data_ptr())[i], ((float*)ref.data_ptr())[i]);
1773
  }
1774
#endif
1775
}
1776

1777
TEST_F(Kernel, Strided1dWithinBounds) {
  auto ir = R"IR(
    graph(%0 : Float(3, strides=[1], device=cpu),
          %1 : Float(3, strides=[2], device=cpu)):
        %2 : int = prim::Constant[value=1]()
        %3 : Float(3, strides=[1]) = aten::add(%0, %1, %2)
        return (%3))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto a = at::rand({3}, TensorOptions(kCPU).dtype(at::kFloat));
  // Every other element of a 6-element tensor gives the stride-2 input %1.
  auto b = at::rand({6}, TensorOptions(kCPU).dtype(at::kFloat))
               .index({Slice(None, None, 2)});
  auto expect = a + b;

  std::vector<at::Tensor> inputs = {a, b};
  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);

  auto output = stack[0].toTensor();
  // Exact comparison reported through gtest (a fatal TORCH_CHECK_EQ would
  // abort the whole test binary on mismatch).
  ASSERT_TRUE(at::equal(output, expect));
}
1806

1807
TEST_F(Kernel, InputAsOutput) {
  // The graph returns its inputs unchanged, so the kernel must pass both
  // tensors (one contiguous, one transposed) straight through.
  const auto graph_string = R"IR(
      graph(%x : Float(5, 3, strides=[3, 1], device=cpu),
            %y : Float(5, 3, strides=[1, 5], device=cpu)):
        return (%x, %y))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({5, 3}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y =
      at::rand({3, 5}, TensorOptions(kCPU).dtype(at::kFloat)).transpose(0, 1);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y};

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  // gtest assertions instead of the fatal CHECK macro, so a mismatch is
  // reported as a test failure rather than aborting the test binary.
  ASSERT_TRUE(at::allclose(x, stack[0].toTensor()));
  ASSERT_TRUE(at::allclose(y, stack[1].toTensor()));
}
1826

1827
TEST_F(Kernel, ScalarOut) {
  auto ir = R"IR(
graph(%x : int, %y : int):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  return (%r, %z))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);

  auto stmt = k.getCodeGenStmt();
  std::ostringstream oss;
  oss << *stmt;

  // Verify the generated IR. We expect to see a scalar variable (Let) followed
  // by a store to a 0-dim buffer.
  const std::string& verification_pattern = R"IR(
# CHECK: int64_t
# CHECK-NEXT: [0ll] =
# CHECK-NEXT: int64_t
# CHECK-NEXT: [0ll] =
)IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  int64_t x = 2, y = 3, r = 0, z = 0;

  // Verify that TEK::runFast works correctly with scalar outputs
  std::vector<void*> inputs = {&x, &y};
  std::vector<void*> outputs = {&r, &z};
  k.runFast(inputs, outputs);
  ASSERT_EQ(z, x * y);
  ASSERT_EQ(r, z * x);

  // Verify that TEK::run works correctly with scalar outputs. Results are on
  // the stack in graph return order (%r, %z).
  std::vector<IValue> stack = {x, y};
  k.run(stack);
  ASSERT_EQ(stack[0].toInt(), x * y * x);
  ASSERT_EQ(stack[1].toInt(), x * y);
}
1867

1868
TEST_F(Kernel, ScalarTensorOut) {
  auto ir = R"IR(
graph(%x : int,
      %xt : Long(3, strides=[1], device=cpu),
      %y : int,
      %yt : Long(3, strides=[1], device=cpu)):
  %z : int = aten::mul(%x, %y)
  %r : int = aten::mul(%z, %x)
  %zt : Long(3, strides=[1], device=cpu) = aten::mul(%xt, %y)
  %rt : Long(3, strides=[1], device=cpu) = aten::mul(%zt, %xt)
  return (%r, %rt, %z, %zt))IR";
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(ir, graph.get(), vmap);
  TensorExprKernel k(graph);
  // yt is filled with y's value, so xt * yt matches the graph's xt * y.
  int64_t x = 2, y = 3, r = 0, z = 0;
  auto xt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 2;
  auto yt = at::ones({3}, TensorOptions(kCPU).dtype(at::kLong)) * 3;
  auto zt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));
  auto rt = at::zeros({3}, TensorOptions(kCPU).dtype(at::kLong));

  // Verify that TEK::runFast works correctly with mixed scalar and tensor
  // inputs/outputs
  std::vector<void*> inputs = {&x, xt.data_ptr(), &y, yt.data_ptr()};
  std::vector<void*> outputs = {&r, rt.data_ptr(), &z, zt.data_ptr()};
  k.runFast(inputs, outputs);
  ASSERT_EQ(z, x * y);
  ASSERT_EQ(r, z * x);
  ASSERT_TRUE(at::equal(zt, xt * yt));
  ASSERT_TRUE(at::equal(rt, zt * xt));

  // Verify that TEK::run works correctly with mixed scalar and tensor
  // inputs/outputs
  std::vector<IValue> stack = {x, xt, y, yt};
  k.run(stack);
  ASSERT_EQ(stack[0].toInt(), x * y * x);
  ASSERT_TRUE(at::equal(stack[1].toTensor(), xt * yt * xt));
  ASSERT_EQ(stack[2].toInt(), x * y);
  ASSERT_TRUE(at::equal(stack[3].toTensor(), xt * yt));
}
1908

1909
TEST_F(Kernel, FuseLoopsWithVariableBounds) {
#ifdef TORCH_ENABLE_LLVM
  // Force cat-without-conditionals for this test. The previous global value is
  // restored on scope exit, so it is also restored if a FileCheck throws or an
  // assertion bails out early (a trailing manual restore would be skipped).
  struct RestoreCatWoConditionals {
    bool saved;
    ~RestoreCatWoConditionals() {
      getCatWoConditionals() = saved;
    }
  } restore_cat{getCatWoConditionals()};
  getCatWoConditionals() = true;

  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), 3, SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), 7, SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), 9, SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), 19, SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3};

  // All tensor inputs and the output are declared contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Expect one outer `i` loop enclosing three j/k nests (one per cat input),
  // i.e. the outer loops were fused.
  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // dim1 binds SS(-2) and dim2 binds SS(-3); they are passed after the tensors
  // in the order of the graph's %SS_2, %SS_3 inputs.
  auto run_kernel = [&](int dim1, int dim2) {
    auto a =
        at::rand({dim1, 3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, 7, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, 9, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20);
#endif
}
1981

1982
TEST_F(Kernel, FuseLoopsWithVariableConcatDim) {
#ifdef TORCH_ENABLE_LLVM
  // Force cat-without-conditionals for this test. The previous global value is
  // restored on scope exit, so it is also restored if a FileCheck throws or an
  // assertion bails out early (a trailing manual restore would be skipped).
  struct RestoreCatWoConditionals {
    bool saved;
    ~RestoreCatWoConditionals() {
      getCatWoConditionals() = saved;
    }
  } restore_cat{getCatWoConditionals()};
  getCatWoConditionals() = true;

  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %c : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b, %c)
        %r : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5};

  // All tensor inputs and the output are declared contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->inputs().at(2)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Expect one outer `i` loop enclosing three j/k nests (one per cat input),
  // even though the concat dimension itself is symbolic.
  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // Symbolic sizes are pushed in %SS_2..%SS_5 order: SS(-2)=dim1,
  // SS(-3)=dim2, SS(-4)=dim3, SS(-5)=3*dim3 (the concatenated size).
  auto run_kernel = [&](int dim1, int dim2, int dim3) {
    auto a =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto c =
        at::rand({dim1, dim3, dim2}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b, c}, 1);

    std::vector<IValue> stack =
        fmap<IValue>(std::vector<at::Tensor>({a, b, c}));
    stack.emplace_back(dim1);
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(3 * dim3);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15);
#endif
}
2058

2059
TEST_F(Kernel, DoNotFuseLoopsWithMismatchingVariableDims) {
#ifdef TORCH_ENABLE_LLVM
  // Force cat-without-conditionals for this test. The previous global value is
  // restored on scope exit, so it is also restored if a FileCheck throws or an
  // assertion bails out early (a trailing manual restore would be skipped).
  struct RestoreCatWoConditionals {
    bool saved;
    ~RestoreCatWoConditionals() {
      getCatWoConditionals() = saved;
    }
  } restore_cat{getCatWoConditionals()};
  getCatWoConditionals() = true;

  const auto graph_string = R"IR(
      graph(%a : Float(SS(-2), SS(-4), SS(-3), requires_grad=0, device=cpu),
            %b : Float(SS(-2), SS(-5), SS(-3), requires_grad=0, device=cpu),
            %SS_2 : int,
            %SS_3 : int,
            %SS_4 : int,
            %SS_5 : int,
            %SS_6 : int):
        %dim : int = prim::Constant[value=1]()
        %inputs : Tensor[] = prim::ListConstruct(%a, %b)
        %r : Float(SS(-2), SS(-6), SS(-3), requires_grad=0, device=cpu) = aten::cat(%inputs, %dim)               # new size: [5,19,2]
        return (%r))IR";
  std::shared_ptr<Graph> graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  std::vector<int64_t> symbolic_shape_inputs = {-2, -3, -4, -5, -6};

  // All tensor inputs and the output are declared contiguous.
  std::vector<torch::jit::StrideInput> input_desc = {
      torch::jit::StrideInput::TENSOR_CONT};
  std::unordered_map<
      const torch::jit::Value*,
      std::vector<torch::jit::StrideInput>>
      symbolic_strides;
  symbolic_strides[graph->inputs().at(0)] = input_desc;
  symbolic_strides[graph->inputs().at(1)] = input_desc;
  symbolic_strides[graph->outputs().at(0)] = input_desc;

  TensorExprKernel kernel(
      graph, {}, symbolic_shape_inputs, false, symbolic_strides);

  // Expect exactly two j/k nests under one outer `i` loop (one per cat input)
  // and no extra loops.
  std::ostringstream oss;
  oss << *kernel.getCodeGenStmt();
  const std::string& verification_pattern =
      R"IR(
# CHECK: for (int64_t i
# CHECK-NEXT: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK: for (int64_t j
# CHECK-NEXT: for (int64_t k
# CHECK-NOT: for (int64_t j
# CHECK-NOT: for (int64_t i
      )IR";
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());

  // Symbolic sizes are pushed in %SS_2..%SS_6 order; SS(-6) is the
  // concatenated size dim4 + dim5.
  auto run_kernel = [&](int dim2, int dim3, int dim4, int dim5) {
    auto a =
        at::rand({dim2, dim4, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));
    auto b =
        at::rand({dim2, dim5, dim3}, at::TensorOptions(kCPU).dtype(at::kFloat));

    auto ref = at::cat({a, b}, 1);

    std::vector<IValue> stack = fmap<IValue>(std::vector<at::Tensor>({a, b}));
    stack.emplace_back(dim2);
    stack.emplace_back(dim3);
    stack.emplace_back(dim4);
    stack.emplace_back(dim5);
    stack.emplace_back(dim4 + dim5);
    kernel.run(stack);

    auto o = stack[0].toTensor();
    ASSERT_TRUE(at::allclose(o, ref));
  };

  run_kernel(10, 20, 15, 8);
#endif
}
2131

2132
} // namespace jit
2133
} // namespace torch
2134

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.