// pytorch — Fork: 0 / test_llvm.cpp
// 1799 lines · 49.8 KB
1
#ifdef TORCH_ENABLE_LLVM
2
#include <gtest/gtest.h>
3

4
#include <test/cpp/tensorexpr/test_base.h>
5

6
#include <c10/util/irange.h>
7
#include <test/cpp/tensorexpr/padded_buffer.h>
8
#include <test/cpp/tensorexpr/test_utils.h>
9
#include <torch/csrc/jit/tensorexpr/eval.h>
10
#include <torch/csrc/jit/tensorexpr/ir.h>
11
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
12
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
13
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
14
#include <torch/csrc/jit/tensorexpr/loopnest.h>
15
#include <torch/csrc/jit/tensorexpr/tensor.h>
16
#include <torch/csrc/jit/testing/file_check.h>
17

18
#include <cmath>
19
#include <numeric>
20

21
namespace torch {
22
namespace jit {
23
using namespace torch::jit::tensorexpr;
24

25
// Shorthand for an expression evaluator whose code is generated by the LLVM
// backend; the scalar immediate/arithmetic/cast tests below evaluate through it.
using LLVMExprEval = ExprEval<LLVMCodeGen>;

// Scalar types covered by the typed tests below. Each X-macro entry is
// (C++ type, prefix used to form `Name##Imm`, sample value). gtest's
// parameterized tests are not used because of how these tests are
// instantiated (one TEST per entry via macro expansion).
#define TEST_LLVM_SCALAR_TYPES(_) \
  _(uint8_t, Byte, 24)            \
  _(int8_t, Char, -20)            \
  _(int16_t, Short, 3332)         \
  _(int, Int, 123456)             \
  _(int64_t, Long, 2631563121321) \
  _(float, Float, 0.122)          \
  _(double, Double, 0.21312)      \
  _(at::Half, Half, 0.128f)

// For every scalar type: compile a bare immediate and read its value back.
// Floating-point samples are compared with a tolerance, integral ones exactly.
#define IMM_TEST(Type, Name, Val)                   \
  TEST(LLVM, Name##ImmTest) {                       \
    auto imm = Name##Imm::make(Val);                \
    LLVMExprEval eval(imm);                         \
    if (!std::is_floating_point<decltype(Val)>()) { \
      ASSERT_EQ(eval.value<Type>(), Val);           \
    } else {                                        \
      ASSERT_NEAR(eval.value<Type>(), Val, 0.1);    \
    }                                               \
  }
TEST_LLVM_SCALAR_TYPES(IMM_TEST)
#undef IMM_TEST

// For every scalar type: Val + (Val * 2) must evaluate to Val * 3.
#define ADD_TEST(Type, Name, Val)                    \
  TEST(LLVM, Name##AddTest) {                        \
    auto lhs = Name##Imm::make(Val);                 \
    auto rhs = Name##Imm::make(Val * 2);             \
    auto sum = Add::make(lhs, rhs);                  \
    LLVMExprEval eval(sum);                          \
    if (!std::is_floating_point<decltype(Val)>()) {  \
      ASSERT_EQ(eval.value<Type>(), Val * 3);        \
    } else {                                         \
      ASSERT_NEAR(eval.value<Type>(), Val * 3, 0.1); \
    }                                                \
  }
TEST_LLVM_SCALAR_TYPES(ADD_TEST)
#undef ADD_TEST

// For every scalar type: (Val * 2) - Val must evaluate to Val.
#define SUB_TEST(Type, Name, Val)                   \
  TEST(LLVM, Name##SubTest) {                       \
    auto minuend = Name##Imm::make(Val * 2);        \
    auto subtrahend = Name##Imm::make(Val);         \
    auto diff = Sub::make(minuend, subtrahend);     \
    LLVMExprEval eval(diff);                        \
    if (!std::is_floating_point<decltype(Val)>()) { \
      ASSERT_EQ(eval.value<Type>(), Val);           \
    } else {                                        \
      ASSERT_NEAR(eval.value<Type>(), Val, 0.1);    \
    }                                               \
  }
TEST_LLVM_SCALAR_TYPES(SUB_TEST)
#undef SUB_TEST

// For every scalar type: Val * 4 must evaluate to the host-computed Val * 4.
// The comparison uses a tolerance when the sample value is floating-point
// and exact equality otherwise. The constant 4 is converted with static_cast
// (not a C-style cast) to the tested type before building the immediate.
#define MUL_TEST(Type, Name, Val)                   \
  TEST(LLVM, Name##MulTest) {                       \
    auto a = Name##Imm::make(Val);                  \
    auto b = Name##Imm::make(static_cast<Type>(4)); \
    auto c = Mul::make(a, b);                       \
    LLVMExprEval cg(c);                             \
    if (std::is_floating_point<decltype(Val)>()) {  \
      ASSERT_NEAR(cg.value<Type>(), Val * 4, 0.1);  \
    } else {                                        \
      ASSERT_EQ(cg.value<Type>(), Val * 4);         \
    }                                               \
  }
TEST_LLVM_SCALAR_TYPES(MUL_TEST)
#undef MUL_TEST

// For every scalar type: 6 / 3 must evaluate to 2. Val only selects the
// comparison: a tolerance when the sample value is floating-point, exact
// equality otherwise. Constants are converted with static_cast rather than
// C-style casts.
#define DIV_TEST(Type, Name, Val)                   \
  TEST(LLVM, Name##DivTest) {                       \
    auto a = Name##Imm::make(static_cast<Type>(6)); \
    auto b = Name##Imm::make(static_cast<Type>(3)); \
    auto c = Div::make(a, b);                       \
    LLVMExprEval cg(c);                             \
    if (std::is_floating_point<decltype(Val)>()) {  \
      ASSERT_NEAR(cg.value<Type>(), 2, 0.1);        \
    } else {                                        \
      ASSERT_EQ(cg.value<Type>(), 2);               \
    }                                               \
  }
TEST_LLVM_SCALAR_TYPES(DIV_TEST)
#undef DIV_TEST

TEST(LLVM, IntToFloatCastTest) {
  // Casting the int immediate 2 to float must yield exactly 2.0.
  auto as_float = Cast::make(kFloat, IntImm::make(2));
  LLVMExprEval eval(as_float, {});
  ASSERT_EQ(eval.value<float>(), 2.0);
}

TEST(LLVM, FloatToIntCastTest) {
  // Casting the float immediate 2.0 to int must yield 2.
  auto as_int = Cast::make(kInt, FloatImm::make(2.0));
  LLVMExprEval eval(as_int);
  ASSERT_EQ(eval.value<int>(), 2);
}

TEST(LLVM, IntToLongCastTest) {
  // Widening an int immediate to a 64-bit long preserves its value.
  auto widened = Cast::make(kLong, IntImm::make(12345));
  LLVMExprEval eval(widened);
  ASSERT_EQ(eval.value<int64_t>(), 12345);
}

TEST(LLVM, ByteToCharCastTest) {
  // Converting byte 250 to char must agree with the host's int8_t conversion.
  auto as_char = Cast::make(kChar, ByteImm::make(250));
  LLVMExprEval eval(as_char);
  ASSERT_EQ(eval.value<int8_t>(), static_cast<int8_t>(250));
}

TEST(LLVM, HalfToLongCastTest) {
  // Casting the half immediate 2.0 to long must yield 2.
  auto as_long = Cast::make(kLong, HalfImm::make(2.0));
  LLVMExprEval eval(as_long);
  ASSERT_EQ(eval.value<int64_t>(), 2);
}

TEST(LLVM, ByteToDoubleCastTest) {
  // Casting byte 2 to double must yield 2.
  auto as_double = Cast::make(kDouble, ByteImm::make(2));
  LLVMExprEval eval(as_double);
  ASSERT_EQ(eval.value<double>(), 2);
}

TEST(LLVM, FloatToByteCastTest) {
  // Casting float 254.0 to byte must yield 254 (within uint8_t range).
  auto as_byte = Cast::make(kByte, FloatImm::make(254.0));
  LLVMExprEval eval(as_byte);
  ASSERT_EQ(eval.value<uint8_t>(), 254);
}

TEST(LLVM, FloatToCharCastTest) {
  // Casting float -2.0 to char must yield -2.
  auto as_char = Cast::make(kChar, FloatImm::make(-2.0));
  LLVMExprEval eval(as_char);
  ASSERT_EQ(eval.value<int8_t>(), -2);
}

TEST(LLVM, ByteToFloatCastTest) {
  // Casting byte 254 to float must yield exactly 254.0.
  auto as_float = Cast::make(kFloat, ByteImm::make(254));
  LLVMExprEval eval(as_float);
  ASSERT_EQ(eval.value<float>(), 254.0);
}

TEST(LLVM, CharToFloatCastTest) {
  // Casting char -2 to float must yield exactly -2.0 (sign preserved).
  auto as_float = Cast::make(kFloat, CharImm::make(-2));
  LLVMExprEval eval(as_float);
  ASSERT_EQ(eval.value<float>(), -2.0);
}

TEST(LLVM, BitCast) {
  // Each case builds an immediate whose bit pattern is that of a reference
  // value of another type (via raw_bitcast), applies BitCast, and expects the
  // reference value back.
  /* constexpr int16_t ref16 = 1337; */
  constexpr int32_t ref32 = 1337;
  constexpr int64_t ref64 = 1337;
  constexpr float reff32 = 1337.0f;
  // Use a double literal for the double reference (was 1337.0f).
  constexpr double reff64 = 1337.0;

  // TODO: the Half -> Short case is disabled; it is marked as broken.
  /*{
    at::Half k_;
    at::Half* k = &k_;
    *reinterpret_cast<int16_t*>(k) = ref16;
    auto a = HalfImm::make(k);
    auto b = BitCast::make(kShort, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<int16_t>(), ref16);
  }*/

  // float bits -> int32
  {
    float k = raw_bitcast<float>(ref32);
    auto a = FloatImm::make(k);
    auto b = BitCast::make(kInt, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<int32_t>(), ref32);
  }

  // double bits -> int64
  {
    double k = raw_bitcast<double>(ref64);
    auto a = DoubleImm::make(k);
    auto b = BitCast::make(kLong, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<int64_t>(), ref64);
  }

  // int64 bits -> double
  {
    int64_t k = raw_bitcast<int64_t>(reff64);
    auto a = LongImm::make(k);
    auto b = BitCast::make(kDouble, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<double>(), reff64);
  }

  // int32 bits -> float
  {
    int32_t k = raw_bitcast<int32_t>(reff32);
    auto a = IntImm::make(k);
    auto b = BitCast::make(kFloat, a);
    LLVMExprEval cg(b);
    ASSERT_EQ(cg.value<float>(), reff32);
  }
}

TEST(LLVM, fastLogFloat) {
  // Compares the LLVM-compiled fast_log against std::log on random inputs.
  const int kTotalSize = 128 * 128;
  BufHandle a_buf("A", {ExprHandle(kTotalSize)}, kFloat);
  BufHandle b_buf("B", {ExprHandle(kTotalSize)}, kFloat);

  VarHandle index = VarHandle("index", kInt);
  ExprHandle load_a = a_buf.load(index);
  StmtPtr store_b = b_buf.store({index}, fast_log(load_a));
  StmtPtr stmt = For::make(index, 0, kTotalSize, store_b);

  PaddedBuffer<float> a_v(kTotalSize);
  PaddedBuffer<float> b_v(kTotalSize);

  // Draw every input with a single randn call instead of allocating one
  // single-element tensor (and calling item()) per element. Negative samples
  // exercise the NaN path of log.
  const at::Tensor samples = at::randn({kTotalSize});
  const float* sample_data = samples.data_ptr<float>();
  for (const auto i : c10::irange(kTotalSize)) {
    a_v(i) = sample_data[i];
  }

  LLVMCodeGen ir_eval(stmt, {a_buf, b_buf});
  ir_eval.call({a_v, b_v});

  for (const auto i : c10::irange(kTotalSize)) {
    auto test = b_v(i);
    auto ref = std::log(a_v(i));
    if (std::isnan(ref)) {
      ASSERT_TRUE(std::isnan(test));
    } else {
      ASSERT_FLOAT_EQ(test, ref);
    }
  }
}

TEST(LLVM, LetTest01) {
  // let x = 3.f; A[0] = 2 + (x * 3 + 4)
  BufHandle out("A", {1}, kFloat);
  std::vector<float> storage = {1, 0};
  std::vector<void*> call_args({storage.data()});
  VarHandle x("x", kFloat);
  auto body = Block::make(
      {Let::make(x, 3.f),
       out.store(
           {0}, ExprHandle(2.f) + (x * ExprHandle(3.f) + ExprHandle(4.f)))});

  LLVMCodeGen codegen(body, {out});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(storage[0], 2.f + 3.f * 3.f + 4.f);
}

TEST(LLVM, LetTest02) {
  // let x = 3.f; let y = 6.f; A[0] = 2 + (x * 3 + y * 4)
  BufHandle out("A", {1}, kFloat);
  std::vector<float> storage = {1, 0};
  std::vector<void*> call_args({storage.data()});
  VarHandle x("x", kFloat);
  VarHandle y("y", kFloat);
  auto body = Block::make(
      {Let::make(x, 3.f),
       Let::make(y, 6.f),
       out.store(
           {IntImm::make(0)},
           ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f)))});

  LLVMCodeGen codegen(body, {out});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(storage[0], 2.f + 3.f * 3.f + 6.f * 4.f);
}

TEST(LLVM, LetTestMultitype) {
  // Lets of mixed dtypes (x: byte, y: half) feed an arithmetic expression
  // whose result is cast to double before being stored into A[0].
  BufHandle out("A", {1}, kDouble);
  std::vector<double> storage = {1, 0};
  std::vector<void*> call_args({storage.data()});
  VarHandle x("x", kByte);
  VarHandle y("y", kHalf);
  ExprHandle combined =
      ExprHandle(2.f) + (x * ExprHandle(3.f) + y * ExprHandle(4.f));
  auto body = Block::make(
      {Let::make(x, 3),
       Let::make(y, 6.f),
       out.store({0}, Cast::make(kDouble, combined))});

  LLVMCodeGen codegen(body, {out});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(storage[0], 2.f + 3 * 3.f + 6.f * 4.f);
}

TEST(LLVM, BufferTest) {
  // Evaluates a constant expression while a buffer argument is bound but never
  // referenced. The host storage now matches the declared buffer (32 floats)
  // instead of 5 int32 values.
  BufHandle a("A", {32}, kFloat);
  std::vector<float> v(32);
  std::vector<void*> args({v.data()});
  auto rv = IntImm::make(0);
  LLVMExprEval cg(rv, {a});
  ASSERT_EQ(cg.value<int>(args), 0);
}

TEST(LLVM, BlockTest) {
  // Stores in a Block run in order, so the second write to A[0] wins.
  BufHandle buf("A", {32}, kInt);
  std::vector<int32_t> data = {1, 2};
  std::vector<void*> call_args({data.data()});

  auto body = Block::make({
      buf.store({0}, 3),
      buf.store({1}, 4),
      buf.store({0}, 4),
  });

  LLVMCodeGen codegen(body, {buf});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(data[0], 4);
  ASSERT_EQ(data[1], 4);
}

TEST(LLVM, LoadStoreTest) {
  // B[0] = A[0]; the source must be left intact.
  BufHandle src("A", {1}, kInt);
  BufHandle dst("B", {1}, kInt);
  auto copy = dst.store({0}, src.load(0));
  LLVMCodeGen codegen(copy, {src, dst});

  std::vector<int32_t> src_data = {42};
  std::vector<int32_t> dst_data = {-11};
  std::vector<void*> call_args({src_data.data(), dst_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(src_data[0], 42);
  ASSERT_EQ(dst_data[0], 42);
}

TEST(LLVM, IfThenElseTest) {
  // B[0] = C[0] ? A[0] : 0 with C[0] == 1, so B receives A's value.
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  BufHandle c("C", {1}, kInt);
  auto chosen = IfThenElse::make(c.load(0), a.load(0), 0);
  auto store = b.store({0}, chosen);
  LLVMCodeGen codegen(store, {a, b, c});

  std::vector<int32_t> a_data = {42};
  std::vector<int32_t> b_data = {-11};
  std::vector<int32_t> c_data = {1};
  std::vector<void*> call_args({a_data.data(), b_data.data(), c_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  ASSERT_EQ(a_data[0], 42);
  ASSERT_EQ(b_data[0], 42);
}

// if (x < 10) x = x + 1
374
TEST(LLVM, CondNoFalseBlockTest) {
375
  BufHandle x("X", {1}, kInt);
376
  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
377
  auto cond = Cond::make(cmp, x.store({0}, x.load(0) + 1), nullptr);
378

379
  for (int32_t x_value : {0, 10, 20}) {
380
    std::vector<int32_t> x_buffer = {x_value};
381
    std::vector<void*> args({x_buffer.data()});
382
    LLVMCodeGen cg(cond, {x});
383
    ASSERT_EQ(cg.value<int>(args), 0);
384
    if (x_value < 10) {
385
      ASSERT_EQ(x_buffer[0], x_value + 1);
386
    } else {
387
      ASSERT_EQ(x_buffer[0], x_value);
388
    }
389
  }
390
}
391

392
// if (x < 10) {
393
//   x = x + 1;
394
// } else {
395
//   x = x - 1;
396
// }
397
TEST(LLVM, CondTest) {
398
  BufHandle x("X", {1}, kInt);
399
  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
400
  auto cond =
401
      Cond::make(cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
402
  auto block = Block::make({
403
      cond,
404
      x.store({0}, x.load(0) * 2),
405
  });
406

407
  for (int32_t x_value : {0, 10, 20}) {
408
    std::vector<int32_t> x_buffer = {x_value};
409
    std::vector<void*> args({x_buffer.data()});
410
    LLVMCodeGen cg(block, {x});
411
    ASSERT_EQ(cg.value<int>(args), 0);
412
    if (x_value < 10) {
413
      ASSERT_EQ(x_buffer[0], (x_value + 1) * 2);
414
    } else {
415
      ASSERT_EQ(x_buffer[0], (x_value - 1) * 2);
416
    }
417
  }
418
}
419

420
// if (x < 10) {
421
//   if (x > 5) {
422
//     x = x + 1;
423
//   } else {
424
//     x = x - 1;
425
//   }
426
// } else {
427
//   if (x <= 15) {
428
//     x = x + 2;
429
//   } else {
430
//     x = x - 2;
431
//   }
432
// }
433
TEST(LLVM, CondNestedTest) {
434
  BufHandle x("X", {1}, kInt);
435
  auto true_cmp =
436
      CompareSelect::make(x.load(0), 5, CompareSelectOperation::kGT);
437
  auto true_cond = Cond::make(
438
      true_cmp, x.store({0}, x.load(0) + 1), x.store({0}, x.load(0) - 1));
439
  auto false_cmp =
440
      CompareSelect::make(x.load(0), 15, CompareSelectOperation::kLE);
441
  auto false_cond = Cond::make(
442
      false_cmp, x.store({0}, x.load(0) + 2), x.store({0}, x.load(0) - 2));
443
  auto cmp = CompareSelect::make(x.load(0), 10, CompareSelectOperation::kLT);
444
  auto cond = Cond::make(cmp, true_cond, false_cond);
445

446
  for (int32_t x_value : {0, 8, 15, 20}) {
447
    std::vector<int32_t> x_buffer = {x_value};
448
    std::vector<void*> args({x_buffer.data()});
449
    LLVMCodeGen cg(cond, {x});
450
    ASSERT_EQ(cg.value<int>(args), 0);
451
    if (x_value < 10) {
452
      if (x_value > 5) {
453
        ASSERT_EQ(x_buffer[0], x_value + 1);
454
      } else {
455
        ASSERT_EQ(x_buffer[0], x_value - 1);
456
      }
457
    } else {
458
      if (x_value <= 15) {
459
        ASSERT_EQ(x_buffer[0], x_value + 2);
460
      } else {
461
        ASSERT_EQ(x_buffer[0], x_value - 2);
462
      }
463
    }
464
  }
465
}
466

467
TEST(LLVM, DirectVectorization) {
  // Hand-vectorized kernel: one N-lane store per row,
  //   c[m*N : m*N + N] = a[m*N : ...] * b[m*N : ...]   for m in [0, M).
  // The ramp base/width are tied to N rather than a duplicated literal 64,
  // and the compiled kernel is executed and its output checked.
  constexpr int M = 3;
  constexpr int N = 64;
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {M, N}, kFloat);
  BufHandle c("c", {M, N}, kFloat);
  VarHandle m("m", kInt);
  StmtPtr s = For::make(
      m,
      0,
      M,
      Store::make(
          c,
          {Ramp::make(m * N, 1, N)},
          Load::make({kFloat, N}, a, {Ramp::make(m * N, 1, N)}) *
              Load::make({kFloat, N}, b, {Ramp::make(m * N, 1, N)})));
  LLVMCodeGen cg(s, {a, b, c});

  std::vector<float> a_buffer(M * N, 2.0f);
  std::vector<float> b_buffer(M * N, 3.0f);
  std::vector<float> c_buffer(M * N, 0.0f);
  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(c_buffer, 6.0f);
}

TEST(LLVM, VecLoadStoreTest) {
  // 4-lane vector copy B[0:4] = A[0:4]. The buffers are declared with the 4
  // elements actually accessed (previously {1}, narrower than the ramp).
  BufHandle a("A", {4}, kInt);
  BufHandle b("B", {4}, kInt);
  std::vector<int32_t> a_buffer = {1, 1, 1, 1};
  std::vector<int32_t> b_buffer = {2, 2, 2, 2};

  auto store = b.store({Ramp::make(0, 1, 4)}, a.load({Ramp::make(0, 1, 4)}));
  LLVMCodeGen cg(store, {a, b});
  std::vector<void*> args({a_buffer.data(), b_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);
  // Source untouched, every destination lane overwritten.
  assertAllEqual(a_buffer, 1);
  assertAllEqual(b_buffer, 1);
}

// Defines a test applying the float intrinsic `Name` to `Lanes` lanes as one
// vector operation, B = Name(A). It checks that the input is left untouched
// and that every output lane matches the host's std::Name. The output buffer
// starts at 0 so a kernel that never writes is detected. (Previously only the
// input buffer was asserted, so the intrinsic result was never checked.)
#define FLOAT_INTRINSICS_TEST(Name, Lanes)                                   \
  TEST(LLVM, VecFloat_##Name##Lane##Lanes##Test) {                           \
    BufHandle a("A", {Lanes}, kFloat);                                       \
    BufHandle b("B", {Lanes}, kFloat);                                       \
    float val = 0.5f;                                                        \
    std::vector<float> a_buffer(Lanes, val);                                 \
    std::vector<float> b_buffer(Lanes, 0.0f);                                \
    auto store = b.store(                                                    \
        {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
    LLVMCodeGen cg(store, {a, b});                                           \
    std::vector<void*> args({a_buffer.data(), b_buffer.data()});             \
    ASSERT_EQ(cg.value<int>(args), 0);                                       \
    const float expected = std::Name(val);                                   \
    for (const auto i : c10::irange(Lanes)) {                                \
      ASSERT_FLOAT_EQ(a_buffer[i], val);                                     \
      ASSERT_NEAR(b_buffer[i], expected, 1e-5f);                             \
    }                                                                        \
  }
FLOAT_INTRINSICS_TEST(erf, 4)
FLOAT_INTRINSICS_TEST(erfc, 4)
FLOAT_INTRINSICS_TEST(acos, 4)
FLOAT_INTRINSICS_TEST(asin, 4)
FLOAT_INTRINSICS_TEST(atan, 4)
FLOAT_INTRINSICS_TEST(cosh, 4)
FLOAT_INTRINSICS_TEST(sinh, 4)
FLOAT_INTRINSICS_TEST(tanh, 4)
FLOAT_INTRINSICS_TEST(expm1, 4)
FLOAT_INTRINSICS_TEST(lgamma, 4)
FLOAT_INTRINSICS_TEST(erf, 8)
FLOAT_INTRINSICS_TEST(erfc, 8)
FLOAT_INTRINSICS_TEST(acos, 8)
FLOAT_INTRINSICS_TEST(asin, 8)
FLOAT_INTRINSICS_TEST(atan, 8)
FLOAT_INTRINSICS_TEST(cosh, 8)
FLOAT_INTRINSICS_TEST(sinh, 8)
FLOAT_INTRINSICS_TEST(tanh, 8)
FLOAT_INTRINSICS_TEST(expm1, 8)
FLOAT_INTRINSICS_TEST(lgamma, 8)
#undef FLOAT_INTRINSICS_TEST

// Defines a test applying the double intrinsic `Name` to `Lanes` lanes as one
// vector operation, B = Name(A). It checks that the input is left untouched
// and that every output lane matches the host's std::Name. The sample value
// is a double, doubles are compared with double-precision assertions, and the
// output buffer starts at 0 so a kernel that never writes is detected.
#define DOUBLE_INTRINSICS_TEST(Name, Lanes)                                  \
  TEST(LLVM, VecDouble_##Name##Lane##Lanes##Test) {                          \
    BufHandle a("A", {Lanes}, kDouble);                                      \
    BufHandle b("B", {Lanes}, kDouble);                                      \
    double val = 0.5;                                                        \
    std::vector<double> a_buffer(Lanes, val);                                \
    std::vector<double> b_buffer(Lanes, 0.0);                                \
    auto store = b.store(                                                    \
        {Ramp::make(0, 1, Lanes)}, Name(a.load({Ramp::make(0, 1, Lanes)}))); \
    LLVMCodeGen cg(store, {a, b});                                           \
    std::vector<void*> args({a_buffer.data(), b_buffer.data()});             \
    ASSERT_EQ(cg.value<int>(args), 0);                                       \
    const double expected = std::Name(val);                                  \
    for (const auto i : c10::irange(Lanes)) {                                \
      ASSERT_DOUBLE_EQ(a_buffer[i], val);                                    \
      ASSERT_NEAR(b_buffer[i], expected, 1e-10);                             \
    }                                                                        \
  }
DOUBLE_INTRINSICS_TEST(erf, 2)
DOUBLE_INTRINSICS_TEST(erfc, 2)
DOUBLE_INTRINSICS_TEST(acos, 2)
DOUBLE_INTRINSICS_TEST(asin, 2)
DOUBLE_INTRINSICS_TEST(atan, 2)
DOUBLE_INTRINSICS_TEST(cosh, 2)
DOUBLE_INTRINSICS_TEST(sinh, 2)
DOUBLE_INTRINSICS_TEST(tanh, 2)
DOUBLE_INTRINSICS_TEST(expm1, 2)
DOUBLE_INTRINSICS_TEST(lgamma, 2)
DOUBLE_INTRINSICS_TEST(erf, 4)
DOUBLE_INTRINSICS_TEST(erfc, 4)
DOUBLE_INTRINSICS_TEST(acos, 4)
DOUBLE_INTRINSICS_TEST(asin, 4)
DOUBLE_INTRINSICS_TEST(atan, 4)
DOUBLE_INTRINSICS_TEST(cosh, 4)
DOUBLE_INTRINSICS_TEST(sinh, 4)
DOUBLE_INTRINSICS_TEST(tanh, 4)
DOUBLE_INTRINSICS_TEST(expm1, 4)
DOUBLE_INTRINSICS_TEST(lgamma, 4)
#undef DOUBLE_INTRINSICS_TEST

TEST(LLVM, VectorizerLoadStoreTest) {
  // c[i] = A[i] for 4 elements, vectorized by LoopNest::vectorize, then run.
  // A is declared with the 4 elements it is read at (previously {1}).
  BufHandle a("A", {4}, kInt);

  Tensor c = Compute("c", {4}, [&](const VarHandle& i) { return a.load(i); });

  BufHandle c_buf(c.buf());
  LoopNest l({c});
  StmtPtr s = l.root_stmt();
  ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(s)->front())));

  // The loop was replaced, so the block no longer starts with a For.
  ASSERT_TRUE(to<For>(to<Block>(s)->front()) == nullptr);

  LLVMCodeGen cg(s, {a, c_buf});

  std::vector<int> a_vec(4, 21);
  std::vector<int> c_vec(4, 0);
  std::vector<void*> args({a_vec.data(), c_vec.data()});
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(c_vec, 21);
}

TEST(LLVM, VectorizeBitCast) {
  // c[i] = bitcast<float>(A[i]) over 128 elements, vectorized, then executed.
  BufHandle a("A", {128}, kInt);

  Tensor c = Compute("c", {128}, [&](const VarHandle& i) {
    return bitcast<float>(a.load(i));
  });

  BufHandle c_buf(c.buf());
  LoopNest nest({c});
  StmtPtr root = nest.root_stmt();
  ASSERT_TRUE(LoopNest::vectorize(to<For>(to<Block>(root)->front())));
  // The loop was replaced, so the block no longer starts with a For.
  ASSERT_TRUE(to<For>(to<Block>(root)->front()) == nullptr);

  LLVMCodeGen codegen(root, {a, c_buf});

  // Every input holds the int bit pattern of 1337.f.
  std::vector<int> a_vec(128, raw_bitcast<int>(1337.f));
  std::vector<float> c_vec(128);
  std::vector<void*> call_args({a_vec.data(), c_vec.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);
  assertAllEqual(c_vec, 1337.f);
}

TEST(LLVM, MemcpyTest) {
  // Element-wise copy B[i] = A[i]; A must be unchanged afterwards.
  constexpr int N = 32;
  BufHandle src("A", {N}, kInt);
  BufHandle dst("B", {N}, kInt);

  VarHandle i("i", kInt);
  auto copy_loop = For::make(i, 0, N, dst.store({i}, src.load(i)));
  LLVMCodeGen codegen(copy_loop, {src, dst});

  std::vector<int32_t> src_data(N, 42);
  std::vector<int32_t> dst_data(N, 0);
  std::vector<void*> call_args({src_data.data(), dst_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(src_data.size(), N);
  ASSERT_EQ(dst_data.size(), N);
  assertAllEqual(src_data, 42);
  assertAllEqual(dst_data, 42);
}

TEST(LLVM, BzeroTest) {
  // B[i] = 0 for every element; B starts out filled with 11.
  constexpr int N = 32;
  BufHandle dst("B", {N}, kInt);

  VarHandle i("i", kInt);
  auto zero_loop = For::make(i, 0, N, dst.store({i}, 0));
  LLVMCodeGen codegen(zero_loop, {dst});

  std::vector<int32_t> dst_data(N, 11);
  std::vector<void*> call_args({dst_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(dst_data.size(), N);
  assertAllEqual(dst_data, 0);
}

TEST(LLVM, ElemwiseAdd) {
  // C[i] = A[i] + B[i] on int32; inputs must remain untouched.
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kInt);
  BufHandle rhs("B", {N}, kInt);
  BufHandle out("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto body = out.store({i}, Add::make(lhs.load(i), rhs.load(i)));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<int32_t> lhs_data(N, 41);
  std::vector<int32_t> rhs_data(N, 1);
  std::vector<int32_t> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(lhs_data, 41);
  assertAllEqual(rhs_data, 1);
  assertAllEqual(out_data, 42);
}

TEST(LLVM, ElemwiseAddFloat) {
  // C[i] = A[i] + B[i] on float via operator+; inputs must remain untouched.
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kFloat);
  BufHandle rhs("B", {N}, kFloat);
  BufHandle out("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store({i}, lhs.load(i) + rhs.load(i));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<float> lhs_data(N, 41);
  std::vector<float> rhs_data(N, 1);
  std::vector<float> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(lhs_data, 41.0f);
  assertAllEqual(rhs_data, 1.0f);
  assertAllEqual(out_data, 42.0f);
}

TEST(LLVM, ElemwiseLog10Float) {
  // B = log10(A), four lanes per iteration over N / 4 iterations.
  constexpr int N = 1024;
  BufHandle in("A", {N}, kFloat);
  BufHandle out("B", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store(
      {Ramp::make(i * 4, 1, 4)}, log10(in.load({Ramp::make(i * 4, 1, 4)})));
  auto loop = For::make(i, 0, N / 4, body);
  LLVMCodeGen codegen(loop, {in, out});

  std::vector<float> in_data(N, 10.0f);
  std::vector<float> out_data(N, 2.0f);
  std::vector<void*> call_args({in_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(in_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(in_data, 10.0f);
  assertAllEqual(out_data, 1.0f);
}

TEST(LLVM, ElemwiseLog1pFloat) {
  // B = log1p(A) with A = e^3 - 1, so every output should be close to 3.
  constexpr int N = 1024;
  BufHandle in("A", {N}, kFloat);
  BufHandle out("B", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store(
      {Ramp::make(i * 4, 1, 4)}, log1p(in.load({Ramp::make(i * 4, 1, 4)})));
  auto loop = For::make(i, 0, N / 4, body);
  LLVMCodeGen codegen(loop, {in, out});

  std::vector<float> in_data(N, expf(3.0f) - 1);
  std::vector<float> out_data(N, 42.0f);
  std::vector<void*> call_args({in_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(in_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(in_data, expf(3.0f) - 1);
  ExpectAllNear(out_data, 3.0f, 1e-5f);
}

TEST(LLVM, ElemwiseMaxInt) {
  // C[i] = max(A[i], B[i]) on int32 (propagate_nans = false).
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kInt);
  BufHandle rhs("B", {N}, kInt);
  BufHandle out("C", {N}, kInt);

  VarHandle i("i", kInt);
  auto body = out.store({i}, Max::make(lhs.load(i), rhs.load(i), false));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<int> lhs_data(N, 41);
  std::vector<int> rhs_data(N, 1);
  std::vector<int> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(lhs_data, 41);
  assertAllEqual(rhs_data, 1);
  assertAllEqual(out_data, 41);
}

TEST(LLVM, ElemwiseMinInt) {
  // C[i] = min(A[i], B[i]) on int32 (propagate_nans = false).
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 41);
  std::vector<int> b_buffer(N, 1);
  // Sentinel distinct from the expected result (1): previously C was
  // pre-filled with 1, so the test passed even if the kernel never wrote C.
  std::vector<int> c_buffer(N, -1);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 1);
  assertAllEqual(c_buffer, 1);
}

TEST(LLVM, ElemwiseMaxFloat) {
  // C[i] = max(A[i], B[i]) on float (propagate_nans = false).
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kFloat);
  BufHandle rhs("B", {N}, kFloat);
  BufHandle out("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store({i}, Max::make(lhs.load(i), rhs.load(i), false));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<float> lhs_data(N, 41);
  std::vector<float> rhs_data(N, 1);
  std::vector<float> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(lhs_data, 41.0f);
  assertAllEqual(rhs_data, 1.0f);
  assertAllEqual(out_data, 41.0f);
}

TEST(LLVM, ElemwiseMaxNaNFloat) {
  // C[i] = max(NaN, 1) with propagate_nans = false; the test expects every
  // output lane to be NaN.
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kFloat);
  BufHandle rhs("B", {N}, kFloat);
  BufHandle out("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store({i}, Max::make(lhs.load(i), rhs.load(i), false));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<float> lhs_data(N, NAN);
  std::vector<float> rhs_data(N, 1);
  std::vector<float> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(rhs_data, 1.0f);
  for (const float value : out_data) {
    ASSERT_TRUE(std::isnan(value));
  }
}

TEST(LLVM, ElemwiseMinFloat) {
  // C[i] = min(A[i], B[i]) on float (propagate_nans = false).
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kFloat);
  std::vector<float> a_buffer(N, 41);
  std::vector<float> b_buffer(N, 1);
  // Sentinel distinct from the expected result (1.0f): previously C was
  // pre-filled with 1, so the test passed even if the kernel never wrote C.
  std::vector<float> c_buffer(N, -1.0f);

  VarHandle i("i", kInt);
  auto expr =
      For::make(i, 0, N, c.store({i}, Min::make(a.load(i), b.load(i), false)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41.0f);
  assertAllEqual(b_buffer, 1.0f);
  assertAllEqual(c_buffer, 1.0f);
}

TEST(LLVM, ElemwiseMinNaNFloat) {
  // C[i] = min(NaN, 1) with propagate_nans = false; the test expects every
  // output lane to be NaN.
  constexpr int N = 1024;
  BufHandle lhs("A", {N}, kFloat);
  BufHandle rhs("B", {N}, kFloat);
  BufHandle out("C", {N}, kFloat);

  VarHandle i("i", kInt);
  auto body = out.store({i}, Min::make(lhs.load(i), rhs.load(i), false));
  auto loop = For::make(i, 0, N, body);
  LLVMCodeGen codegen(loop, {lhs, rhs, out});

  std::vector<float> lhs_data(N, NAN);
  std::vector<float> rhs_data(N, 1);
  std::vector<float> out_data(N, 1);
  std::vector<void*> call_args(
      {lhs_data.data(), rhs_data.data(), out_data.data()});
  ASSERT_EQ(codegen.value<int>(call_args), 0);

  ASSERT_EQ(lhs_data.size(), N);
  ASSERT_EQ(rhs_data.size(), N);
  ASSERT_EQ(out_data.size(), N);
  assertAllEqual(rhs_data, 1.0f);
  for (const float value : out_data) {
    ASSERT_TRUE(std::isnan(value));
  }
}

TEST(LLVM, ElemwiseMod) {
  // C[i] = A[i] % B[i] on int32: 41 % 23 == 18.
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int32_t> a_buffer(N, 41);
  std::vector<int32_t> b_buffer(N, 23);
  // Sentinel distinct from the expected result (18): previously C was
  // pre-filled with 18, so the test passed even if the kernel never wrote C.
  std::vector<int32_t> c_buffer(N, -1);

  VarHandle i("i", kInt);
  auto expr = For::make(i, 0, N, c.store({i}, Mod::make(a.load(i), b.load(i))));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);
  assertAllEqual(a_buffer, 41);
  assertAllEqual(b_buffer, 23);
  assertAllEqual(c_buffer, 18);
}

// kEQ comparison on int inputs. The first half of B is 0 while A is all 1, so
// the expected result is 0 for the first half and 1 for the second half.
TEST(LLVM, CompareSelectIntEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kInt);
  BufHandle b("B", {N}, kInt);
  BufHandle c("C", {N}, kInt);
  std::vector<int> a_buffer(N, 1);
  std::vector<int> b_buffer(N, 1);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 1);

  for (const auto idx : c10::irange(N / 2)) {
    b_buffer[idx] = 0;
    c_ref[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto is_equal =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kEQ);
  auto loop = For::make(i, 0, N, c.store({i}, is_equal));

  LLVMCodeGen cg(loop, {a, b, c});
  std::vector<void*> args{a_buffer.data(), b_buffer.data(), c_buffer.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(a_buffer, 1);
  ASSERT_EQ(c_buffer, c_ref);
}
991

992
// kEQ comparison on float inputs with an int result buffer. The first half of
// B is set to 0.0f while A stays 1.0f, so both the "not equal" (0) and the
// "equal" (1) outcomes are checked. With all-equal inputs, a kernel that
// ignored its inputs and always produced 1 would pass.
TEST(LLVM, CompareSelectFloatEQ) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kFloat);
  BufHandle b("B", {N}, kFloat);
  BufHandle c("C", {N}, kInt);
  std::vector<float> a_buffer(N, 1.0f);
  std::vector<float> b_buffer(N, 1.0f);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 1);

  for (const auto idx : c10::irange(N / 2)) {
    b_buffer[idx] = 0.0f;
    c_ref[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kEQ)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  // A is an input and must be left untouched.
  assertAllEqual(a_buffer, 1.0f);
  for (const auto idx : c10::irange(N)) {
    ASSERT_EQ(c_ref[idx], c_buffer[idx]);
  }
}
1024

1025
// kGT comparison on uint8 inputs. A is 128 in the first half and 0 elsewhere,
// and B is all 0, so only the first half compares greater. If the bytes were
// compared as signed, 128 would be negative and the first half would not be
// greater.
TEST(LLVM, CompareSelectByteGT) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);
  std::vector<uint8_t> a_buffer(N, 0);
  std::vector<uint8_t> b_buffer(N, 0);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 0);

  for (const auto idx : c10::irange(N / 2)) {
    a_buffer[idx] = 128;
    c_ref[idx] = 1;
  }

  VarHandle i("i", kInt);
  auto is_greater =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kGT);
  auto loop = For::make(i, 0, N, c.store({i}, is_greater));

  LLVMCodeGen cg(loop, {a, b, c});
  std::vector<void*> args{a_buffer.data(), b_buffer.data(), c_buffer.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(b_buffer, uint8_t(0));
  ASSERT_EQ(c_buffer, c_ref);
}
1064

1065
// kGE comparison on uint8 inputs. B is all 128. A is 0 in the first half, where
// 0 >= 128 is false, and 128 in the second half, where 128 >= 128 is true.
// Previously every pair was 0 >= 0, so a kernel that always produced 1 passed.
// The false half also catches a signed comparison, under which 128 would be
// negative and 0 >= 128 would wrongly hold.
TEST(LLVM, CompareSelectByteGE) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);
  std::vector<uint8_t> a_buffer(N, 128);
  std::vector<uint8_t> b_buffer(N, 128);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 1);

  for (const auto idx : c10::irange(N / 2)) {
    a_buffer[idx] = 0;
    c_ref[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kGE)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  // B is an input and must be left untouched.
  assertAllEqual(b_buffer, uint8_t(128));
  for (const auto idx : c10::irange(N)) {
    ASSERT_EQ(c_ref[idx], c_buffer[idx]);
  }
}
1099

1100
// kLT comparison on uint8 inputs with B all 128. The first half of A is 128,
// where 128 < 128 is false. The rest of A is 0, where 0 < 128 is true.
TEST(LLVM, CompareSelectByteLT) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);
  std::vector<uint8_t> a_buffer(N, 0);
  std::vector<uint8_t> b_buffer(N, 128);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 1);

  for (const auto idx : c10::irange(N / 2)) {
    a_buffer[idx] = 128;
    c_ref[idx] = 0;
  }

  VarHandle i("i", kInt);
  auto is_less =
      CompareSelect::make(a.load(i), b.load(i), CompareSelectOperation::kLT);
  auto loop = For::make(i, 0, N, c.store({i}, is_less));

  LLVMCodeGen cg(loop, {a, b, c});
  std::vector<void*> args{a_buffer.data(), b_buffer.data(), c_buffer.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  assertAllEqual(b_buffer, uint8_t(128));
  ASSERT_EQ(c_buffer, c_ref);
}
1139

1140
// kLE comparison on uint8 inputs with B all 128. A is split into three regions:
//   [0, N/2)      A = 200 -> 200 <= 128 is false
//   [N/2, 3N/4)   A = 128 -> equality, true
//   [3N/4, N)     A = 0   -> 0 <= 128 is true (false under a signed compare)
// Previously every pair was 0 <= 128, so a kernel that always produced 1
// passed.
TEST(LLVM, CompareSelectByteLE) {
  constexpr int N = 1024;
  BufHandle a("A", {N}, kByte);
  BufHandle b("B", {N}, kByte);
  BufHandle c("C", {N}, kInt);
  std::vector<uint8_t> a_buffer(N, 0);
  std::vector<uint8_t> b_buffer(N, 128);
  std::vector<int> c_buffer(N, 0);
  std::vector<int> c_ref(N, 1);

  for (const auto idx : c10::irange(N / 2)) {
    a_buffer[idx] = 200;
    c_ref[idx] = 0;
  }
  for (const auto idx : c10::irange(N / 2, 3 * N / 4)) {
    a_buffer[idx] = 128;
  }

  VarHandle i("i", kInt);
  auto expr = For::make(
      i,
      0,
      N,
      c.store(
          {i},
          CompareSelect::make(
              a.load(i), b.load(i), CompareSelectOperation::kLE)));

  LLVMCodeGen cg(expr, {a, b, c});

  std::vector<void*> args({a_buffer.data(), b_buffer.data(), c_buffer.data()});
  ASSERT_EQ(cg.value<int>(args), 0);

  ASSERT_EQ(a_buffer.size(), N);
  ASSERT_EQ(b_buffer.size(), N);
  ASSERT_EQ(c_buffer.size(), N);

  // B is an input and must be left untouched.
  assertAllEqual(b_buffer, uint8_t(128));
  for (const auto idx : c10::irange(N)) {
    ASSERT_EQ(c_ref[idx], c_buffer[idx]);
  }
}
1174

1175
// Stores the float immediate 3.14f into a one-element buffer and checks that
// exactly that value is written back to the host.
TEST(LLVM, StoreFloat) {
  BufHandle result("result", {1}, kFloat);
  auto store_pi = result.store({0}, FloatImm::make(3.14f));
  LLVMCodeGen cg(store_pi, {result});

  std::vector<float> result_buffer{0.0f};
  std::vector<void*> args{result_buffer.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  ASSERT_EQ(result_buffer[0], 3.14f);
}
1184

1185
// Computes f[i] = float(i * i + 1) and compares the kernel output with a
// host-side reference.
TEST(LLVM, SimpleMath01) {
  const int N = 1024;
  Tensor tensor = Compute("f", {N}, [](const VarHandle& x) {
    return cast<float>(x * x + 1);
  });
  LoopNest l({tensor});
  StmtPtr stmt = l.root_stmt();
  LLVMCodeGen cg(stmt, {BufHandle(tensor.buf())});

  PaddedBuffer<float> f_ref(N, "f_ref");
  for (const auto i : c10::irange(N)) {
    f_ref(i) = i * i + 1;
  }

  PaddedBuffer<float> f_v(N, "f_v");
  std::vector<void*> args{f_v.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  ExpectAllNear(f_v, f_ref, 1e-5);
}
1204

1205
// c[i] = a[i] * b[i] expressed through Compute; with a = 21 and b = 2 every
// output element must be 42.
TEST(LLVM, ComputeMul) {
  const int N = 1024;
  BufHandle a("a", {N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute("c", {N}, [&](const VarHandle& idx) {
    return a.load(idx) * b.load(idx);
  });

  BufHandle c_buf(c.buf());
  LoopNest l({c});
  LLVMCodeGen cg(l.root_stmt(), {a, b, c_buf});

  std::vector<float> a_vec(N, 21.0f);
  std::vector<float> b_vec(N, 2.0f);
  std::vector<float> c_vec(N, 0.0f);
  std::vector<void*> args{a_vec.data(), b_vec.data(), c_vec.data()};
  ASSERT_EQ(cg.value<int>(args), 0);
  assertAllEqual(c_vec, 42.0f);
}
1225

1226
// Broadcasts the length-N buffer `b` across the M rows of `a`:
// c[i][j] = a[i][j] + b[j]. Both inputs are filled with 0, 1, 2, ...
TEST(LLVM, BroadcastAdd) {
  const int M = 32;
  const int N = 1024;
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute(
      "c", {M, N}, [&](const VarHandle& row, const VarHandle& col) {
        return a.load(row, col) + b.load(col);
      });

  BufHandle c_buf(c.buf());
  LoopNest l({c});
  l.prepareForCodegen();
  LLVMCodeGen cg(l.root_stmt(), {a, b, c_buf});

  std::vector<float> av(M * N);
  std::iota(av.begin(), av.end(), 0);
  std::vector<float> bv(N);
  std::iota(bv.begin(), bv.end(), 0);
  std::vector<float> cv(M * N, 0);
  std::vector<void*> args{av.data(), bv.data(), cv.data()};
  ASSERT_EQ(cg.value<int>(args), 0);

  for (const auto row : c10::irange(M)) {
    const float* a_row = av.data() + row * N;
    const float* c_row = cv.data() + row * N;
    for (const auto col : c10::irange(N)) {
      ASSERT_EQ(c_row[col], a_row[col] + bv[col]);
    }
  }
}
1256

1257
// Evaluates a chain of bitwise operators on int immediates:
//   ((59 ^ (11 << 1)) & 101) >> 2 | 2
//   = ((59 ^ 22) & 101) >> 2 | 2
//   = (45 & 101) >> 2 | 2
//   = 37 >> 2 | 2 = 9 | 2 = 11
TEST(LLVM, BitwiseOps) {
  auto a = IntImm::make(59);
  auto b = IntImm::make(11);
  auto c = IntImm::make(101);
  auto d = IntImm::make(2);

  ExprHandle doubled = b << 1;
  ExprHandle masked = (a ^ doubled) & c;
  ExprHandle f = (masked >> 2) | d;

  LLVMExprEval cg(f);
  ASSERT_EQ(cg.value<int>(), 11);
}
1268

1269
// `>>` on signed 8-bit immediates keeps the sign: -4 >> 1 must be -2.
TEST(LLVM, ArithmeticRightShift) {
  ExprHandle shifted = CharImm::make(-4) >> CharImm::make(1);
  LLVMExprEval cg(shifted);
  ASSERT_EQ(cg.value<int8_t>(), -2);
}
1276

1277
// `>>` on unsigned 8-bit immediates shifts in zeros: 0xfc >> 1 must be 0x7e.
TEST(LLVM, LogicalRightShift) {
  ExprHandle shifted = ByteImm::make(0xfc) >> ByteImm::make(1);
  LLVMExprEval cg(shifted);
  ASSERT_EQ(cg.value<uint8_t>(), 0x7e);
}
1284

1285
// Adds two 1-D buffers whose length is the runtime variable `n`. The size is
// passed through the raw argument vector as a pointer to an int32_t. Three
// sizes are exercised, including a length of 1 and a non-power-of-two length.
TEST(LLVM, DynamicShapeAdd) {
  auto testWithSize = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    LLVMCodeGen cg(s, {a, b, c, n});
    std::vector<void*> args({aData.data(), bData.data(), cData.data(), &size});
    // Read and check the kernel's return value as `int`, like every other
    // statement kernel run through value<>() in this file. The previous
    // value<float>() requested a return type those kernels are never read
    // as, and its result was discarded.
    ASSERT_EQ(cg.value<int>(args), 0);
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  testWithSize(1);
  testWithSize(16);
  testWithSize(37);
}
1305

1306
// Adds two 1-D buffers of runtime length `n`. The host vectors and the size
// are handed to LLVMCodeGen::call as call arguments rather than as raw
// pointers.
TEST(LLVM, BindDynamicShapeAdd) {
  auto run_case = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    BufHandle c("c", {n}, kFloat);
    VarHandle i("i", kInt);
    StmtPtr s = For::make(i, 0, n, c.store({i}, a.load(i) + b.load(i)));
    LLVMCodeGen cg(s, {a, b, c, n});

    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    cg.call({aData, bData, cData, size});
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  for (const int32_t size : {1, 16, 37}) {
    run_case(size);
  }
}
1325

1326
// Dynamic-length add where the output comes from a Compute tensor instead of a
// hand-built loop. The length `n` is bound at call time.
TEST(LLVM, TensorDynamicShapeAdd) {
  auto run_case = [](int32_t size) {
    VarHandle n("n", kInt);
    BufHandle a("a", {n}, kFloat);
    BufHandle b("b", {n}, kFloat);
    Tensor c = Compute("c", {n}, [&](const VarHandle& idx) {
      return a.load(idx) + b.load(idx);
    });
    LoopNest l({c});
    LLVMCodeGen cg(l.root_stmt(), {a, b, c, n});

    std::vector<float> aData(size, 1.0f);
    std::vector<float> bData(size, 2.0f);
    std::vector<float> cData(size, 0.0f);
    cg.call({aData, bData, cData, size});
    ExpectAllNear(cData, std::vector<float>(size, 3.0f), 1e-7);
  };
  for (const int32_t size : {1, 16, 37}) {
    run_case(size);
  }
}
1346

1347
// 2-D element-wise add where both extents (`m` and `n`) are runtime
// variables. Each case passes M and N explicitly at call time.
TEST(LLVM, DynamicShape2D) {
  auto run_case = [](int32_t M, int32_t N) {
    VarHandle m("m", kInt);
    VarHandle n("n", kInt);
    BufHandle a("a", {m, n}, kFloat);
    BufHandle b("b", {m, n}, kFloat);
    Tensor c =
        Compute("c", {m, n}, [&](const VarHandle& row, const VarHandle& col) {
          return a.load(row, col) + b.load(row, col);
        });
    LoopNest l({c});
    l.prepareForCodegen();
    LLVMCodeGen cg(l.root_stmt(), {a, b, c, m, n});

    const int32_t total = M * N;
    std::vector<float> aData(total, 1.0f);
    std::vector<float> bData(total, 2.0f);
    std::vector<float> cData(total, 0.0f);
    cg.call({aData, bData, cData, M, N});
    ExpectAllNear(cData, std::vector<float>(total, 3.0f), 1e-7);
  };
  run_case(1, 8);
  run_case(16, 32);
  run_case(37, 11);
}
1371

1372
// Generating and calling a kernel for an empty block with no buffer arguments
// must not crash. Nothing else is checked.
TEST(LLVM, EmptyStmt) {
  StmtPtr empty_block = alloc<Block>(std::vector<StmtPtr>{});
  LLVMCodeGen cg(empty_block, {});
  cg.call({});
}
1379

1380
// A Compute over a zero-length domain, simplified before codegen. Presumably
// the simplifier leaves no work (hence the test name). The kernel must still
// build and accept an empty output buffer alongside an unused input.
TEST(LLVM, EliminatedStmt) {
  BufHandle a("a", {1}, kFloat);
  Tensor c = Compute("c", {0}, [&](const VarHandle& m) { return m; });

  LoopNest l({c});
  l.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(l.root_stmt());
  LLVMCodeGen cg(s, {a, c});

  std::vector<float> aData(1, 1.0f);
  std::vector<float> cData(0, 0.0f);
  cg.call({aData, cData});
}
1394

1395
// Sum-reduces a 1 x M x N buffer to a single value with no loop transforms and
// compares the result against a host-side sum of a[0][i][j] = i + j.
TEST(LLVM, SimpleReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);
  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});

  LoopNest loop({b});
  loop.prepareForCodegen();
  StmtPtr s = IRSimplifier::simplify(loop.root_stmt());
  LLVMCodeGen cg(s, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  float expected = 0;
  for (const auto row : c10::irange(M)) {
    for (const auto col : c10::irange(N)) {
      int v = row + col;
      a_v(0, row, col) = v;
      expected += v;
    }
  }
  b_ref(0) = expected;

  cg.call({a_v, b_v});
  ExpectAllNear(b_v, b_ref, 1e-5);
}
1427

1428
// Sum-reduces a 1 x M x N buffer after swapping the two reduction loops and
// rfactor-ing the reduction over the n loop (now at position 1). The result is
// compared with a host-side sum of a[0][i][j] = i + j.
TEST(LLVM, RFactorReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);

  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest loop({b});

  std::vector<ForPtr> loops = loop.getLoopStmtsFor(b);
  ForPtr loop_m = loops.at(1);
  ForPtr loop_n = loops.at(2);
  loop.reorderAxis(loop_m, loop_n);

  loops = loop.getLoopStmtsFor(b);
  loop_m = loops.at(2);
  loop_n = loops.at(1);
  // Bounds-checked access (as in RFactorVectorizedReduction): fail the test
  // cleanly, instead of reading past the end, if b has fewer than two writes.
  auto b_body = loop.getAllWritesToBuf(b.buf()).at(1);
  ASSERT_TRUE(loop.rfactor(b_body, loop_n));

  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  LLVMCodeGen cg(s, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  b_ref(0) = 0;
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      int v = i + j;
      a_v(0, i, j) = v;
      b_ref(0) += v;
    }
  }

  cg.call({a_v, b_v});

  ExpectAllNear(b_v, b_ref, 1e-5);
}
1471

1472
// Sum-reduces a 1 x M x N buffer after swapping the reduction loops,
// rfactor-ing the outer reduction loop, distributing it, and vectorizing the
// first two distributed loops (the initializer and the producer of the
// rfactor buffer). The result is compared with a host-side sum.
TEST(LLVM, RFactorVectorizedReduction) {
  int M = 128;
  int N = 64;

  BufHandle a("a", {1, M, N}, kFloat);

  Tensor b = Reduce("sum", {1}, Sum(), a, {M, N});
  LoopNest nest({b});
  {
    // Reorder n and m loops.
    std::vector<ForPtr> axes = nest.getLoopStmtsFor(b);
    nest.reorderAxis(axes.at(1), axes.at(2));
  }
  auto reduce_store = nest.getAllWritesToBuf(b.buf()).at(1);
  auto writer_nests = nest.getAllLoopNestsWritingToBuf(b.buf());
  ASSERT_TRUE(writer_nests.size() == 2 && writer_nests[1].size() == 3);
  auto split_axis = writer_nests[1][1];
  ASSERT_TRUE(nest.rfactor(reduce_store, split_axis));
  auto distributed = nest.distributeLoop(split_axis);

  // Vectorize initializer of rfac_buf.
  ASSERT_TRUE(LoopNest::vectorize(distributed[0]));
  // Vectorize producer of rfac_buf.
  ASSERT_TRUE(LoopNest::vectorize(distributed[1]));
  nest.simplify();

  nest.prepareForCodegen();

  StmtPtr s = IRSimplifier::simplify(nest.root_stmt());
  LLVMCodeGen cg(s, {a, b});

  PaddedBuffer<float> a_v(1, M, N, "a_v");
  PaddedBuffer<float> b_v(1, "b_v");
  PaddedBuffer<float> b_ref(1, "b_ref");

  float expected = 0;
  for (const auto row : c10::irange(M)) {
    for (const auto col : c10::irange(N)) {
      int v = row + col;
      a_v(0, row, col) = v;
      expected += v;
    }
  }
  b_ref(0) = expected;

  cg.call({a_v, b_v});
  ExpectAllNear(b_v, b_ref, 1e-5);
}
1517

1518
template <bool outer, bool inner>
1519
static void testSimpleParallel() {
1520
  // Compute a simple operation, and try all loop-axis combination to be
1521
  // parallel or sequential.
1522
  const int M = 4;
1523
  const int N = 6;
1524
  Tensor f = Compute("f", {M, N}, [](const VarHandle& m, const VarHandle& n) {
1525
    return cast<float>(m + n);
1526
  });
1527
  LoopNest loop_nest({f});
1528
  auto const& loops = loop_nest.getLoopStmtsFor(f);
1529
  ForPtr m = loops[0];
1530
  ForPtr n = loops[1];
1531
  if (outer) {
1532
    m->set_parallel();
1533
  }
1534
  if (inner) {
1535
    n->set_parallel();
1536
  }
1537
  loop_nest.prepareForCodegen();
1538
  StmtPtr stmt = loop_nest.root_stmt();
1539
  LLVMCodeGen cg(stmt, {f});
1540

1541
  PaddedBuffer<float> f_v(M, N, "f_v");
1542
  std::vector<void*> args({f_v.data()});
1543
  int value = cg.value<int>(args);
1544
  ASSERT_EQ(value, 0);
1545
  PaddedBuffer<float> f_ref(M, N, "f_ref");
1546
  for (const auto m : c10::irange(M)) {
1547
    for (const auto n : c10::irange(N)) {
1548
      f_ref(m, n) = m + n;
1549
    }
1550
  }
1551
  ExpectAllNear(f_v, f_ref, 1e-5);
1552
}
1553

1554
// One test per (outer, inner) loop configuration: S = sequential,
// P = parallel.
TEST(LLVM, SimpleParallelSS) {
  testSimpleParallel</*outer=*/false, /*inner=*/false>();
}
TEST(LLVM, SimpleParallelSP) {
  testSimpleParallel</*outer=*/false, /*inner=*/true>();
}
TEST(LLVM, SimpleParallelPS) {
  testSimpleParallel</*outer=*/true, /*inner=*/false>();
}
TEST(LLVM, SimpleParallelPP) {
  testSimpleParallel</*outer=*/true, /*inner=*/true>();
}
1566

1567
// Builds t4[m][n] = (m + 1) * (n + 2) + m + n through the intermediate
// Computes t1, t2 and t3. It then tries every one of the 2^6 ways to mark the
// six loops parallel or sequential, checking each kernel against a host-side
// reference.
TEST(LLVM, CompositeParallel) {
  int loop_count = 6;
  int test_count = 1 << loop_count;
  for (const auto test_cfg : c10::irange(test_count)) {
    int M = 5;
    int N = 7;
    Tensor t1 = Compute("t1", {M}, [](const VarHandle& m) { return m + 1.f; });
    Tensor t2 = Compute("t2", {N}, [](const VarHandle& n) { return n + 2.f; });
    Tensor t3 =
        Compute("t3", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
          return t1.load(m) * t2.load(n);
        });
    Tensor t4 =
        Compute("t4", {M, N}, [=](const VarHandle& m, const VarHandle& n) {
          return t3.load(m, n) + m + n;
        });
    LoopNest loop_nest({t4}, {t1, t2, t3, t4});

    // Candidate loops in bit order: t1, t2, t3 (outer, inner), t4 (outer, inner).
    std::vector<ForPtr> loop_list;
    loop_list.push_back(loop_nest.getLoopStmtsFor(t1)[0]);
    loop_list.push_back(loop_nest.getLoopStmtsFor(t2)[0]);
    for (Tensor* t : {&t3, &t4}) {
      auto const& loops = loop_nest.getLoopStmtsFor(*t);
      loop_list.push_back(loops[0]);
      loop_list.push_back(loops[1]);
    }
    ASSERT_EQ(loop_list.size(), loop_count);

    for (const auto bit : c10::irange(loop_count)) {
      if ((test_cfg >> bit) & 1) {
        loop_list[bit]->set_parallel();
      }
    }
    loop_nest.prepareForCodegen();
    LLVMCodeGen cg(loop_nest.root_stmt(), {t4});

    PaddedBuffer<float> t4_v(M, N, "t4_v");
    std::vector<void*> args{t4_v.data()};
    ASSERT_EQ(cg.value<int>(args), 0);

    PaddedBuffer<float> t4_ref(M, N, "t4_ref");
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        t4_ref(m, n) = (m + 1) * (n + 2) + m + n;
      }
    }
    ExpectAllNear(t4_v, t4_ref, 1e-5);
  }
}
1628

1629
// 32x48 by 48x32 GEMM expressed as a Sum reduction. The m and n loops are each
// split by 16 and reordered to mo, no, k, mi, ni. Two loops are vectorized
// before codegen, and the result is compared with a naive host-side GEMM.
TEST(LLVM, VectorizedGEMM) {
  int M = 32;
  int N = 32;
  int K = 48;

  BufHandle AP("A", {M, K}, kFloat);
  BufHandle BP("B", {K, N}, kFloat);
  Tensor CT = Reduce(
      "gemm",
      {M, N},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return AP.load(m, k) * BP.load(k, n);
      },
      {K});
  LoopNest loop({CT});

  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr m = loops[0];
    loop.splitWithMask(m, 16);
  }
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr n = loops[2];
    loop.splitWithMask(n, 16);
  }
  // mo, mi, no, ni, k ->
  // mo, no, mi, ni, k
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[1];
    ForPtr no = loops[2];
    loop.reorderAxis(mi, no);
  }
  // mo, no, mi, ni, k ->
  // mo, no, mi, k, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr ni = loops[3];
    ForPtr k = loops[4];
    loop.reorderAxis(ni, k);
  }
  // mo, no, mi, k, ni ->
  // mo, no, k, mi, ni
  {
    auto const& loops = loop.getLoopStmtsFor(CT);
    ForPtr mi = loops[2];
    ForPtr k = loops[3];
    loop.reorderAxis(mi, k);
  }
  {
    auto loops = NodeFinder<For>::find(loop.root_stmt());
    ASSERT_TRUE(LoopNest::vectorize(loops[3]));
    ASSERT_TRUE(LoopNest::vectorize(loops.back()));
  }

  loop.prepareForCodegen();

  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  LLVMCodeGen cg(s, {AP, BP, CT});

  PaddedBuffer<float> a_v(M, K, "a_v");
  PaddedBuffer<float> b_v(K, N, "b_v");
  PaddedBuffer<float> c_v(M, N, "c_v");
  PaddedBuffer<float> c_ref(M, N, "c_ref");

  // Fill both inputs explicitly with small, varied integers. Previously A and
  // B were used as constructed, so the GEMM was checked only against whatever
  // PaddedBuffer's constructor stores (not visible here). With |a| <= 6 and
  // |b| <= 2, every partial sum (|sum| <= 48 * 12) is an integer exactly
  // representable in float. The expected result therefore does not depend on
  // the order in which the vectorized kernel accumulates.
  for (const auto m : c10::irange(M)) {
    for (const auto k : c10::irange(K)) {
      a_v(m, k) = static_cast<float>((m + 2 * k) % 7);
    }
  }
  for (const auto k : c10::irange(K)) {
    for (const auto n : c10::irange(N)) {
      b_v(k, n) = static_cast<float>((3 * k + n) % 5) - 2.0f;
    }
  }

  for (const auto m : c10::irange(M)) {
    for (const auto n : c10::irange(N)) {
      c_ref(m, n) = 0.f;
      for (const auto k : c10::irange(K)) {
        c_ref(m, n) += a_v(m, k) * b_v(k, n);
      }
    }
  }

  cg.call({a_v, b_v, c_v});

  ExpectAllNear(c_v, c_ref, 1e-5);
}
1710

1711
// Runs the same broadcast-add statement, c[i][j] = a[i][j] + b[j] with a
// dynamic N, through call_raw on both LLVMCodeGen and SimpleIREvaluator. The
// same raw argument vector (including a pointer to N) is used for both.
TEST(LLVM, CallRaw) {
  const int M = 32;
  VarHandle N("N", kInt);
  BufHandle a("a", {M, N}, kFloat);
  BufHandle b("b", {N}, kFloat);
  Tensor c = Compute("c", {M, N}, [&](const VarHandle& i, const VarHandle& j) {
    return a.load(i, j) + b.load(j);
  });

  LoopNest l({c});
  l.prepareForCodegen();
  StmtPtr s = l.root_stmt();

  int32_t N_value = 1024;
  std::vector<float> av(M * N_value);
  std::iota(av.begin(), av.end(), 0);
  std::vector<float> bv(N_value);
  std::iota(bv.begin(), bv.end(), 0);
  std::vector<float> cv(M * N_value, 0);
  std::vector<void*> args({av.data(), bv.data(), cv.data(), &N_value});

  LLVMCodeGen cg(s, {a, b, BufHandle(c.buf()), N});
  cg.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }

  // Overwrite the LLVM results before the evaluator runs. Otherwise the checks
  // below would pass even if the evaluator never wrote to `cv`. Every expected
  // value is >= 0, so -1 can never be mistaken for a result. The elements are
  // assigned in place so cv.data(), already stored in `args`, stays valid.
  for (auto& value : cv) {
    value = -1.0f;
  }

  SimpleIREvaluator eval(s, {a, b, BufHandle(c.buf()), N});
  eval.call_raw(args);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N_value)) {
      ASSERT_EQ(cv[i * N_value + j], av[i * N_value + j] + bv[j]);
    }
  }
}
1750

1751
// Builds d = a * b + c for an explicitly requested target (triple "i686-elf",
// cpu "i386") and inspects the generated assembly. It must contain `fadds`
// followed by `fmuls`, and must not contain `vfmadd` (presumably because that
// CPU has no fused multiply-add; confirm against the LLVM target).
TEST(LLVM, CustomTarget) {
  constexpr int M = 16;
  BufHandle a("a", {M}, kFloat);
  BufHandle b("b", {M}, kFloat);
  BufHandle c("c", {M}, kFloat);
  Tensor d = Compute("d", {M}, [&](const VarHandle& idx) {
    return a.load(idx) * b.load(idx) + c.load(idx);
  });
  LoopNest nest({d});
  nest.prepareForCodegen();

  auto cg = LLVMCodeGenBuilder(nest.root_stmt(), {a, b, c, d})
                .triple("i686-elf")
                .cpu("i386")
                .build();

  std::ostringstream asm_stream;
  asm_stream << cg->getCodeText("asm");
  const std::string asm_text = asm_stream.str();

  torch::jit::testing::FileCheck()
      .check("fadds")
      ->check("fmuls")
      ->check_not("vfmadd")
      ->run(asm_text);
}
1773

1774
// Checks the kernel function name reported by LLVMCodeGen. It must be
// non-empty by default and equal to the caller-supplied name when one is
// passed explicitly.
TEST(LLVM, CodeGenKernelFuncName) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  std::vector<int32_t> a_buffer{42};
  std::vector<int32_t> b_buffer{-11};
  auto copy_a_to_b = b.store({0}, a.load(0));

  {
    // No name given: LLVMCodeGen must still pick a non-empty one.
    LLVMCodeGen default_named(copy_a_to_b, {a, b});
    ASSERT_NE(default_named.kernel_func_name(), "");
  }
  {
    // Explicit name: it must be used verbatim.
    LLVMCodeGen explicitly_named(copy_a_to_b, {a, b}, at::kCPU, "new_func");
    ASSERT_EQ(explicitly_named.kernel_func_name(), "new_func");
  }
}
1795

1796
} // namespace jit
1797
} // namespace torch
1798

1799
#endif // TORCH_ENABLE_LLVM
1800

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.