#include <gtest/gtest.h>

#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <unordered_map>

#include <test/cpp/tensorexpr/test_base.h>

#include <c10/util/irange.h>
#include <test/cpp/tensorexpr/padded_buffer.h>
#include <torch/csrc/jit/tensorexpr/analysis.h>
#include <torch/csrc/jit/tensorexpr/eval.h>
#include <torch/csrc/jit/tensorexpr/ir.h>
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>

namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
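
// Reducing over an empty list of reduce axes is the identity: the reducer
// interaction never runs, so the output is an elementwise copy of the input.
// The next two tests cover these 0-D cases.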
TEST(Reductions, ReduceSum0D_1) {
  const int M = 10;

  BufHandle b("b", {M}, kFloat);
  std::vector<float> in(M);
  for (const auto j : c10::irange(M)) {
    in[j] = j;
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], in[i]);
  }
}

TEST(Reductions, ReduceSum0D_2) {
  BufHandle b("b", {}, kFloat);
  std::vector<float> in(1);
  in[0] = 77.7;

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], in[0]);
}

// Sum an array to a single value.
TEST(Reductions, ReduceSum1D) {
  BufHandle b("b", {10}, kFloat);
  std::vector<float> in(10);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {10});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 45);
}

// Sum a 2D tensor to a 1D tensor with dynamic shapes.
TEST(Reductions, ReduceSum2D) {
  const int M = 3;
  const int N = 7;

  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, n, m});

  cg.call({in, out, 5, 7});

  float expected = 0;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to
// check our work.
TEST(Reductions, ReduceSum3D) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m});

  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> cData(2 * 3, 6.0f);
  std::vector<float> dData(2, 1.0f);
  std::vector<float> eData(2, 1.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({bData, cData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(cData[i], expected);
  }

  Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
  LoopNest loop2({d});
  loop2.prepareForCodegen();
  StmtPtr s2 = loop2.root_stmt();
  s2 = IRSimplifier::simplify(s2);

  SimpleIREvaluator cg2(s2, {b, d, m});
  cg2.call({bData, dData, M});

  // We're combining an additional dimension of 3, so the sum is 3x.
  expected = expected * 3;

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected);
  }

  // This is the same as just reducing the original result across that axis.
  BufHandle c_buf(c.buf());
  Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
  LoopNest loop3({e});
  loop3.prepareForCodegen();
  StmtPtr s3 = loop3.root_stmt();
  s3 = IRSimplifier::simplify(s3);

  SimpleIREvaluator cg3(s3, {c, e});
  cg3.call({cData, eData});

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(eData[i], expected);
  }
}

// Sum a large (10-D) tensor along five of its dimensions.
TEST(Reductions, ReduceSum10D) {
  BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat);
  const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3;
  BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat);
  const int OutputSize = 2 * 3 * 2 * 3 * 2;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(OutputSize, -1.f);

  Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, c});

  cg.call({in, out});

  // NOLINTNEXTLINE(bugprone-integer-division)
  float expected = InputSize / OutputSize;
  for (const auto i : c10::irange(OutputSize)) {
    ASSERT_EQ(out[i], expected);
  }
}
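
// For comparison with the custom Reducer below: the built-in Sum() used above
// presumably behaves like
//   Reducer(ExprHandle(0.f), [](ExprHandle a, ExprHandle b) { return a + b; });
// i.e. an initializer plus an interaction function over the same interface.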
// Reduce via Mul rather than Add using a custom Reducer.
TEST(Reductions, ReduceProduct) {
  const int M = 4;
  const int N = 4;

  BufHandle b("b", {M, N}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = 2 + j;
    }
  }

  std::vector<float> out(M, -1.f);

  Reducer product(
      ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });

  Tensor c = Reduce("product", {M}, product, b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});

  float expected = 1;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected *= 2 + i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}

// Maximum reductions.
TEST(Reductions, ReduceMax) {
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});

  LoopNest loop({dm1});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  SimpleIREvaluator cg(s, {in_, dm1});

  cg.call({in, out});

  ASSERT_EQ(out[0], 9);

  BufHandle in2_("b", {2, 5}, kFloat);
  std::vector<float> out2(2, -1.f);

  Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});

  LoopNest loop2({m2d});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {in2_, m2d});
  cg2.call({in, out2});

  ASSERT_EQ(out2[0], 4);
  ASSERT_EQ(out2[1], 9);
}

// Minimum reduction, with custom initialization.
TEST(Reductions, ReduceMinCustomInitializer) {
  VarHandle minInit("minInit", kFloat);
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = 10 + j;
  }

  Tensor min = Reduce(
      "min",
      {},
      Minimum(ExprHandle(minInit)),
      [&](ParameterList& v) { return in_.load(v); },
      {10});

  LoopNest loop({min});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, min, minInit});

  // Works normally (note that the out data starts lower than the correct
  // minimum).
  cg.call({in, out, std::numeric_limits<float>::max()});
  ASSERT_EQ(out[0], 10);

  // With an initializer lower than the true minimum, the initializer wins.
  cg.call({in, out, 5.f});
  ASSERT_EQ(out[0], 5);
}

// Example implementation of Any/All.
// TODO: this is very awkward without logical And/Or operators.
TEST(Reductions, ReduceAnyAll) {
  VarHandle searchValue("searchValue", kInt);
  BufHandle b("b", {4, 10}, kInt);

  Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 1, 1, b, kEQ);
  });

  Tensor any = Reduce(
      "anyEqual",
      {4},
      anyEqSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kEQ);
      },
      {10});

  LoopNest loop({any});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, any, searchValue});

  std::vector<int> in(40, 0);
  std::vector<int> out(4, 0);

  // Input has 0-39 across 4 rows.
  for (const auto i : c10::irange(40)) {
    in[i] = i;
  }
  cg.call({in, out, 1});

  // Only the first row contains a 1.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  cg.call({in, out, 15});

  // 15 is in the second row.
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 0, 0, b, kEQ);
  });

  Tensor allGreaterThan = Reduce(
      "allGreaterThan",
      {4},
      allGTSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kGT);
      },
      {10});

  LoopNest loop2({allGreaterThan});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue});

  cg2.call({in, out, 11});

  // 11 is in the second row, so only the last two rows are entirely greater.
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);

  cg2.call({in, out, -3});

  // All elements are greater than -3.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);
}
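
// A 2D matrix multiply written as a Sum() reduction: the body loads
// tA(m, k) * tB(k, n) and the single reduce axis is the inner dimension k.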
TEST(Reductions, ReduceMatmul2D) {
  BufHandle tA("tA", {3, 2}, kFloat);
  BufHandle tB("tB", {2, 3}, kFloat);

  std::vector<float> tA_(6);
  std::vector<float> tB_(6);

  std::vector<float> out(9, -1.f);
  for (const auto i : c10::irange(3)) {
    for (const auto j : c10::irange(2)) {
      tA_[i * 2 + j] = i * 2 + j;
      tB_[j * 3 + i] = i * 2 + j;
    }
  }

  Tensor mm = Reduce(
      "mm",
      {3, 3},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return tA.load(m, k) * tB.load(k, n);
      },
      {2});

  LoopNest loop({mm});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {tA, tB, mm});
  cg.call({tA_, tB_, out});

  std::vector<float> expected(
      {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f});

  for (const auto i : c10::irange(9)) {
    ASSERT_EQ(out[i], expected[i]);
  }
}
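
// Build the shape an rfactor would produce by hand: first reduce each row to
// a vector of partial sums, then reduce that vector to a scalar.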
TEST(Reductions, ReduceRfactorLike) {
  BufHandle in("in", {10, 10}, kFloat);
  std::vector<float> in_(100);
  for (const auto i : c10::irange(100)) {
    in_[i] = i;
  }
  std::vector<float> in_rf_(10, -2.f);
  std::vector<float> out(1, -1.f);

  Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
  BufHandle in_rf(l1.buf());

  Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});

  LoopNest loop({l1, l2});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, l1, l2});
  cg.call({in_, in_rf_, out});

  ASSERT_EQ(out[0], 99 * 50);
}
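
// A reduction used as a producer: its output buffer is loaded by a following
// Compute stage.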
TEST(Reductions, ReduceAsProducer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  Tensor d =
      Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
        return c.load(l, n) * a.load(l, n);
      });
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2 * 3, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    aData[i] = 6 - i;
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({aData, bData, dData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }
  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(dData[i], expected * (6 - i));
  }
}
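
// A reduction used as a consumer: it reduces the output of a preceding
// Compute stage.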
TEST(Reductions, ReduceAsConsumer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3, m}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Compute(
      "scale",
      {2, 3, m},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3 * M, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j + 1;
      aData[i * M + j] = 6 - i;
    }
  }

  cg.call({aData, bData, dData, M});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  float expected[2] = {0, 0};
  for (const auto i : c10::irange(2)) {
    for (const auto j : c10::irange(3)) {
      for (const auto k : c10::irange(M)) {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        expected[i] += (k + 1) * (6 - (i * 3 + j));
      }
    }
  }

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected[i]);
  }
}
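
// Splitting the reduce axis with a tail must not change the result.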
TEST(Reductions, SplitReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);

  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[1], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}
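
// Splitting the non-reduce (output) axis, twice over, must not change the
// result either.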
TEST(Reductions, SplitNonReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);
  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::splitWithTail(loops[0], 2);
  LoopNest::splitWithTail(loops[0], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}

TEST(Reductions, ReorderedReductionInitializer) {
  /* From the quip:
  for k in 0..1:  // blockIdx
    for m in 0..128:
      for n in 0..64: // threadIdx
        SumOp(c(k, n), 0, a(k, m, n), {m})
  */

  BufHandle in("in", {1, 12, 6}, kFloat);
  std::vector<float> in_(12 * 6, 1.f);

  Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l_({tensor_});

  l_.prepareForCodegen();
  StmtPtr s_ = Stmt::clone(l_.root_stmt());
  s_ = IRSimplifier::simplify(s_);

  Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l({tensor});

  auto loops = l.getLoopStmtsFor(tensor);
  loops[0]->set_gpu_block_index(0);
  loops[1]->set_gpu_thread_index(0);

  LoopNest::reorderAxis(loops[1], loops[2]);

  StmtPtr s = l.root_stmt();
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  s = IRSimplifier::simplify(s);

  l.prepareForCodegen();

  s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  std::vector<float> out1(16, -1.f);
  SimpleIREvaluator cg(s_, {in, tensor_});
  cg.call({in_, out1});

  std::vector<float> out2(16, -1.f);
  SimpleIREvaluator cg2(s, {in, tensor});
  cg2.call({in_, out2});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out1[i], out2[i]);
  }
}
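
// rfactor splits a reduction in two: a partial reduction into a temporary
// buffer over the rfactored axis, then a final reduction of that buffer.
// Hence two ReduceOps are expected in the IR afterwards.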
TEST(Reductions, ReduceRfactor) {
  const int M = 10;
  const int N = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (int j = 0; j < M * N; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n});

  cg.call({in, out, M, N});
  ASSERT_EQ(out[0], 4950);
}
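
// rfactor over the innermost reduction axis is rejected, leaving the single
// original ReduceOp in place.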
TEST(Reductions, Reduce3DRfactorInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_FALSE(loop.rfactor(c_body, loops.at(2)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 1);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, Reduce3DRfactorOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});
  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}
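
// Apply rfactor repeatedly, once per outer loop, and check each variant
// against an unfactored reference.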
TEST(Reductions, ReduceRepeatedInternalRfactor) {
  BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat);
  const int InputSize = 2 * 3 * 4 * 5 * 6;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(1, -1.f);
  std::vector<float> ref(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
  LoopNest orig_loop({c});

  // Try rfactoring N outer loops.
  for (const auto rfac_number : c10::irange(1, 5)) {
    LoopNest refloop(orig_loop);
    LoopNest loop(orig_loop);
    refloop.prepareForCodegen();
    SimpleIREvaluator ref_cg(
        IRSimplifier::simplify(refloop.root_stmt()), {in_, c});
    ref_cg.call({in, ref});

    BufPtr tmp_buf = c.buf();

    for (const auto idx : c10::irange(rfac_number)) {
      auto reduce = loop.getAllWritesToBuf(tmp_buf)[1];
      ASSERT_TRUE(loop.rfactor(
          reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf));
    }

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {in_, c});
    cg.call({in, out});

    ASSERT_EQ(ref[0], out[0]);
  }
}

// Split a reduction axis with a tail loop.
TEST(Reductions, ReduceSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly so there is no tail loop.
TEST(Reductions, ReduceSplitNoTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with only a tail loop (the split loop will be size 0
// and eliminated).
TEST(Reductions, ReduceOverSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with a mask.
TEST(Reductions, ReduceSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis cleanly, not requiring a mask.
TEST(Reductions, ReduceSplitNoMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Split a reduction axis with all logic in the mask.
TEST(Reductions, ReduceOverSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor when there are two ReduceOps in the graph due to a
// splitWithTail.
TEST(Reductions, ReduceSplitRfactor) {
  const int M = 2;
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 4;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (const auto m : c10::irange(M)) {
    for (int j = 0; j < N * K; ++j) {
      in[m * N * K + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);

  auto c_body = loop.getAllWritesToBuf(c.buf())[2];
  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]);
  all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1]));
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  for (const auto i : c10::irange(M)) {
    (void)i; // Suppress unused variable warning.
    ASSERT_EQ(out[0], 4950);
  }
}

// Test an rfactor which ends up being eliminated since the total loop size is
// smaller than the split factor.
TEST(Reductions, ReduceOverSplitRfactor) {
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 16;

  BufHandle b("b", {N, K}, kFloat);
  std::vector<float> in(N * K);
  for (int j = 0; j < N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr i, t;
  LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t);
  LoopNest::reorderAxis(loops[0], i);

  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0]));
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]);

  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], 4950);

  std::ostringstream oss;
  oss << *cg.stmt();

  // Check the IR to verify the rfactored reduce is eliminated.
  // TODO: The alloc/free should be eliminated here since it is size 0.
  /*
  const std::string& verification_pattern =
      R"IR(
# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0]
# CHECK: sum[0] = 0.f;
# CHECK: for (int n = 0; n < 10; n++) {
# CHECK:   for (int k_tail = 0; k_tail < 10; k_tail++) {
# CHECK:     sum[0] = (sum[0]) + (b[k_tail + 10 * n]);
# CHECK:   }
# CHECK: }
# CHECK: Free(tmp_buf);)IR";
  */
  // TODO: rfactor output is not consistent yet, will fix (@nickg).
  // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}

TEST(Reductions, ReduceInlineReduction) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
  Tensor y = Compute(
      "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });

  PaddedBuffer<float> a_v(M);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    a_v(i) = i * i;
  }
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        b_v(i, j, k) = j * j * k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  // Cannot inline a reduction computation.
  ASSERT_FALSE(l1.computeInline(x.buf()));
}
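
// Inlining a producer Compute into a reduction is legal; it should preserve
// the numerics while producing a shorter IR.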
TEST(Reductions, ReduceInlineConsumer) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });
  Tensor y = Reduce("y", {M}, Sum(), x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}
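
// As above, but with a custom reducer whose interaction function carries
// extra arithmetic of its own.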
TEST(Reductions, ReduceInlineReducerInternal) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });

  Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
    return Add::make(ExprHandle(1.f), Min::make(a, b, false));
  });
  Tensor y = Reduce("y", {M}, minimum, x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}
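
// The next three tests cache the reduction's output buffer at each loop level
// in turn: the operator (output) axis, the outer reduce axis, and the inner
// reduce axis. The cache shrinks from the whole output (dims=[4]) to a single
// element (dims=[1]) as it moves inward.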
TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(
      LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[4]
#CHECK: for (int i_2
#CHECK:   d_local[i_2] = 0.f
#CHECK:   for (int
#CHECK:     for (int
#CHECK:       d_local[i_2] = (d_local[i_2]) + (scale[
#CHECK:     }
#CHECK:   }
#CHECK: }
#CHECK: for (int i_3
#CHECK:   sum[i_3] = d_local[i_3]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: d_local[0] = sum[i_1]
#CHECK: for (int j_1
#CHECK:   for (int k_1
#CHECK: d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK: }
#CHECK: sum[i_1] = d_local[0]
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}

TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
  int L = 4;
  int N = 3;
  int M = 2;

  BufHandle a("a", {L, N, M}, kFloat);
  BufHandle b("b", {L, N, M}, kFloat);

  Tensor c = Compute(
      "scale",
      {L, N, M},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});

  Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});
  LoopNest l_before(l);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[2];
  l.cacheAccesses(d.buf(), "d_local", d_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg_after(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg_after.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(d_local); // dtype=float, dims=[1]
#CHECK: sum[i_1] = 0
#CHECK: for (int
#CHECK:   d_local[0] = 0
#CHECK:   for (int
#CHECK:     d_local[0] = (d_local[0]) + (scale[
#CHECK:   }
#CHECK:   sum[i_1] = (sum[i_1]) + (d_local[0])
#CHECK: }
#CHECK: Free(d_local);
#CHECK-NOT: d_local
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  PaddedBuffer<float> a_v(L, M, N, "a");
  PaddedBuffer<float> b_v(L, M, N, "b");
  PaddedBuffer<float> c_v(L, M, N, "c");
  PaddedBuffer<float> d_v(L, "d");
  PaddedBuffer<float> e_before(L, "e_before");
  PaddedBuffer<float> e_after(L, "e_after");

  for (const auto l : c10::irange(L)) {
    for (const auto m : c10::irange(M)) {
      for (const auto n : c10::irange(N)) {
        a_v(l, m, n) = at::randn({1}).item().to<float>();
        b_v(l, m, n) = at::randn({1}).item().to<float>();
      }
    }
  }

  cg_before.call({a_v, b_v, e_before});
  cg_after.call({a_v, b_v, e_after});

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  ExpectAllNear(e_before, e_after, 1e-5);
}
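
// Cache the producer ("scale") buffer instead of the reduction's output.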
TEST(Reductions, ReductionCacheBodyAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
  l.cacheAccesses(c.buf(), "scale_local", d_loop);

  l.prepareForCodegen();
  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
#CHECK:   for (int k_1 = 0; k_1 < 12; k_1++) {
#CHECK:     scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
#CHECK: Free(scale_local);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
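
// Cache the reduction's output where it is consumed by a later Compute.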
TEST(Reductions, ReductionCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4);

  StmtPtr e_loop = l.getLoopStmtsFor(e)[1];
  l.cacheAccesses(d.buf(), "sum_local", e_loop);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK:   sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK:   scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionSplitCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Split the outer reduction axis.
  LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);

  // Split the reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // The reduction changes, but the cache does not.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK:         sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
#CHECK: for (int i_2 = 0; i_2 < 6
#CHECK:   for (int j_2 = 0; j_2 < 4
#CHECK:     sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK:   for (int j_3 = 0; j_3 < 4
#CHECK:     scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionReorderCacheConsumerAccess) {
  BufHandle a("a", {24, 32, 12}, kFloat);
  BufHandle b("b", {24, 32, 12}, kFloat);

  Tensor c = Compute(
      "scale",
      {24, 32, 12},
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

  Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
    return b.load(0, 0, l) * d.load(l);
  });

  LoopNest l({e}, {c, d, e});

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr inner;

  // Reorder the outer reduction axes.
  auto loops = l.getLoopStmtsFor(d);
  LoopNest::reorderAxis(loops[0], loops[1]);

  // Split the reduction consumer.
  LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

  l.cacheAccesses(d.buf(), "sum_local", inner);
  l.prepareForCodegen();

  StmtPtr result =
      LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
  SimpleIREvaluator cg(result, {a, b, e});

  // Neither the reduction body nor the cache changes.
  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK:        sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
#CHECK:  for (int i_3 = 0; i_3 < 6;
#CHECK:    for (int j_2 = 0; j_2 < 4;
#CHECK:      sum_local[j_2] = sum[j_2 + 4 * i_3];
#CHECK:    for (int j_3 = 0; j_3 < 4;
#CHECK:      scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

TEST(Reductions, ReductionRfactorCacheTempOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});

  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));

  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
  loop.simplify();
  loop.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[j] = 0
#CHECK:   }
#CHECK:   for (int j_1 = 0; j_1 < n
#CHECK:     for (int k
#CHECK:       tmp[j_1] = (tmp[j_1]) + (B[
#CHECK:     }
#CHECK:   }
#CHECK:   for (int j_2 = 0; j_2 < n
#CHECK:     sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
#CHECK:   }
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}

TEST(Reductions, ReductionRfactorCacheTempInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("B", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];

  LoopNest::reorderAxis(loops.at(0), loops.at(1));
  loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  BufPtr rfac_buf;
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
  loop.distributeLoop(loops.at(0));
  auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

  all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
  ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
  LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
  SimpleIREvaluator cg(s, {b, c, m, n, k});

  std::ostringstream oss;
  oss << *cg.stmt();
  const std::string& expected_ir =
      R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK:   for (int j = 0; j < n
#CHECK:     tmp[0] = 0
#CHECK:     for (int k
#CHECK:       tmp[0] = (tmp[0]) + (B[
#CHECK:     }
#CHECK:   sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
#CHECK:   Free(tmp);
#CHECK-NOT: tmp
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  cg.call({in, out, M, N, K});
  ASSERT_EQ(out[0], 499500);
}
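
// Vectorizing the output axis of a reduction: stores become Ramp/Broadcast
// accesses and each reduce iteration updates all lanes at once.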
TEST(Reductions, ReductionVectorize) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(8, -1.f);
  std::vector<float> out_after(8, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

  ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));

  StmtPtr s = l.root_stmt();
  s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));

  std::ostringstream oss;
  oss << *s;
  const std::string& expected_ir =
      R"IR(
#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
#CHECK: }
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  // Vectorizing should not change the result.
  l.prepareForCodegen();
  s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg_after(s, {in, tensor});
  cg_after.call({in_, out_after});
  for (const auto i : c10::irange(8)) {
    ASSERT_EQ(out_before[i], out_after[i]);
  }
}

TEST(Reductions, ReductionVectorizeInner) {
  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
  LoopNest l({tensor});

  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
}

TEST(Reductions, ReductionVectorizeRfactor) {
  std::vector<float> in_(8 * 8);
  for (const auto i : c10::irange(8)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out_before(1, -1.f);
  std::vector<float> out_after(1, -1.f);

  BufHandle in("in", {8, 8}, kFloat);

  Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});

  LoopNest l_before({tensor});
  LoopNest l(l_before);
  l_before.prepareForCodegen();
  SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
  cg_before.call({in_, out_before});

  ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));

  // But if we rfactor so that this is no longer a reduce axis, we can
  // vectorize that loop.
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  LoopNest::reorderAxis(loops[0], loops[1]);
  loops = l.getLoopStmtsFor(tensor);
  auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1];
  BufPtr rfac_buf = nullptr;
  ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));

  LoopNest::distributeLoop(loops.at(0));
  auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf);

  ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
  l.simplify();

  StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());

  std::ostringstream oss;
  oss << *s;
  const std::string& expected_ir =
      R"IR(
#CHECK: sum = 0.f;
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK:   sum_rfac[i] = 0.f;
#CHECK: }
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
#CHECK:   sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
#CHECK: }
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
#CHECK:   sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
#CHECK: }
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());

  // Vectorizing should not change the result.
  l.prepareForCodegen();
  s = IRSimplifier::simplify(l.root_stmt());
  SimpleIREvaluator cg_after(s, {in, tensor});
  cg_after.call({in_, out_after});

  ASSERT_EQ(out_before[0], out_after[0]);
}
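
// Reduce() with an explicit init function: the accumulator for C[i] is seeded
// from B[i] rather than from the reducer's default initializer.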
TEST(Reductions, InitFunction) {
  constexpr int M = 32;
  constexpr int N = 16;
  BufHandle A("A", {M, N}, kFloat);
  BufHandle B("B", {N}, kFloat);
  Tensor C = Reduce(
      "C",
      {N},
      Sum(),
      [&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
      [&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
      {M});
  LoopNest nest({C});
  nest.prepareForCodegen();
  StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
  std::ostringstream oss;
  oss << *s << "\n";
  const std::string& expected_ir =
      R"IR(
#CHECK:  for (int i = 0; i < 16; i++) {
#CHECK:    C[i] = B[i];
#CHECK:    for (int j = 0; j < 32; j++) {
#CHECK:      C[i] = (C[i]) + (A[i + 16 * j]);
#CHECK:    }
#CHECK:  }
      )IR";
  torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}

} // namespace jit
} // namespace torch