pytorch

Форк
0
/
test_registerizer.cpp 
3702 строки · 87.4 Кб
1
#include <gtest/gtest.h>
2
#include "test/cpp/tensorexpr/test_base.h"
3

4
#include "test/cpp/tensorexpr/test_utils.h"
5
#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"
6
#include "torch/csrc/jit/tensorexpr/registerizer.h"
7

8
#include <iostream>
9

10
namespace torch {
11
namespace jit {
12
using namespace torch::jit::tensorexpr;
13

14
// Can replace a simple scalar access with a local variable.
15
TEST(Registerizer, RegisterizerSimple) {
16
  BufHandle a("A", {1}, kInt);
17
  VarHandle x("x", kInt);
18
  StmtPtr stmt = Block::make(
19
      {Store::make(a, {0}, 0),
20
       For::make(
21
           x,
22
           0,
23
           10,
24
           Block::make(
25
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
26

27
  /*
28
   * A[0] = 0;
29
   * for (int x = 0; x < 10; x++) {
30
   *   A[0] = (A[0]) + x;
31
   * }
32
   */
33

34
  stmt = registerize(stmt);
35

36
  /*
37
   * int A_1 = 0;
38
   * for (int x = 0; x < 10; x++) {
39
   *   A_1 = x + A_1;
40
   * }
41
   * A[0] = A_1;
42
   */
43

44
  std::ostringstream oss;
45
  oss << *stmt;
46

47
  const std::string& verification_pattern =
48
      R"IR(
49
# CHECK: int A_1 = 0;
50
# CHECK: for (int x = 0; x < 10; x++)
51
# CHECK-NOT: A[
52
# CHECK:   A_1 =
53
# CHECK: A[0] = A_1;)IR";
54

55
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
56
}
57

58
// Won't do replacement of a loop access.
59
TEST(Registerizer, RegisterizerLoop) {
60
  BufHandle a("A", {10}, kInt);
61
  VarHandle x("x", kInt);
62
  StmtPtr stmt = Block::make(
63
      {Store::make(a, {0}, 0),
64
       For::make(
65
           x,
66
           0,
67
           10,
68
           Block::make(
69
               {Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
70

71
  /*
72
   * A[0] = 0;
73
   * for (int x = 0; x < 10; x++) {
74
   *   A[x] = (A[x]) + x;
75
   * }
76
   */
77

78
  // No change.
79
  stmt = registerize(stmt);
80

81
  /*
82
   * A[0] = 0;
83
   * for (int x = 0; x < 10; x++) {
84
   *   A[x] = (A[x]) + x;
85
   * }
86
   */
87

88
  std::ostringstream oss;
89
  oss << *stmt;
90

91
  const std::string& verification_pattern =
92
      R"IR(
93
# CHECK-NOT: int
94
# CHECK: A[0] = 0;
95
# CHECK: for (int x = 0; x < 10; x++)
96
# CHECK-NOT: A_
97
# CHECK:   A[x] =
98
# CHECK-NOT: A[0] = A_1;)IR";
99

100
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
101
}
102

103
// Won't replace even if the load is a fixed scalar, since the store could
104
// invalidate it.
105
TEST(Registerizer, RegisterizerLoopFixedLoad) {
106
  BufHandle a("A", {1}, kInt);
107
  VarHandle x("x", kInt);
108
  StmtPtr stmt = Block::make(
109
      {Store::make(a, {0}, 0),
110
       For::make(
111
           x,
112
           0,
113
           10,
114
           Block::make(
115
               {Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});
116

117
  /*
118
   * A[0] = 0;
119
   * for (int x = 0; x < 10; x++) {
120
   *   A[x] = (A[0]) + x;
121
   * }
122
   */
123

124
  // No change.
125
  stmt = registerize(stmt);
126

127
  /*
128
   * A[0] = 0;
129
   * for (int x = 0; x < 10; x++) {
130
   *   A[x] = (A[0]) + x;
131
   * }
132
   */
133

134
  std::ostringstream oss;
135
  oss << *stmt;
136

137
  const std::string& verification_pattern =
138
      R"IR(
139
# CHECK-NOT: int
140
# CHECK: A[0] = 0;
141
# CHECK: for (int x = 0; x < 10; x++)
142
# CHECK-NOT: A_
143
# CHECK:   A[x] =
144
# CHECK-NOT: A[0] = A_1;)IR";
145

146
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
147
}
148

149
// We can registerize accesses that occur entirely within inner scopes, even if
150
// they depend on the loop var.
151
TEST(Registerizer, RegisterizerLoopInternal) {
152
  BufHandle a("A", {1}, kInt);
153
  VarHandle x("x", kInt);
154
  StmtPtr stmt = Block::make({For::make(
155
      x,
156
      0,
157
      10,
158
      Block::make(
159
          {Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),
160
           Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});
161

162
  /*
163
   * for (int x = 0; x < 10; x++) {
164
   *   A[x] = (A[x]) + x;
165
   *   A[x] = (A[x]) + x;
166
   * }
167
   */
168

169
  stmt = registerize(stmt);
170

171
  // TODO: the order of terms in addition changes and in general depends on
172
  // some hash value. This results in unpredictable swaps of the operands from
173
  // random changes, which is not great. Ideally, we should ensure some
174
  // specific order (ideally, the original one).
175
  /*
176
   * for (int x = 0; x < 10; x++) {
177
   *   int A_1 = A[x];
178
   *   A_1 = x + A_1;
179
   *   A_1 = x + A_1;
180
   *   A[x] = A_1;
181
   * }
182
   */
183

184
  std::ostringstream oss;
185
  oss << *stmt;
186

187
  const std::string& verification_pattern =
188
      R"IR(
189
# CHECK: for (int x = 0; x < 10; x++)
190
# CHECK: int A_1 = A[x];
191
# CHECK:   A_1 = A_1 + x;
192
# CHECK:   A_1 = A_1 + x;
193
# CHECK:   A[x] = A_1;
194
# CHECK: })IR";
195

196
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
197
}
198

199
// An access can be overlapped by another read in the same Expr. In this case
200
// B[z] and B[y] overlap and prevent registerization of both accesses.
201
TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {
202
  BufHandle a("A", {10}, kInt);
203
  BufHandle b("B", {10}, kInt);
204
  VarHandle x("x", kInt);
205
  VarHandle y("y", kInt);
206
  VarHandle z("z", kInt);
207
  StmtPtr stmt = Block::make({For::make(
208
      x,
209
      0,
210
      10,
211
      Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});
212
  stmt = IRSimplifier::simplify(stmt);
213

214
  /*
215
   * for (int x = 0; x < 10; x++) {
216
   *   A[x] = (B[y]) + (B[z]);
217
   * }
218
   */
219

220
  std::ostringstream before;
221
  before << *stmt;
222

223
  // No change.
224
  stmt = registerize(stmt);
225

226
  std::ostringstream after;
227
  after << *stmt;
228

229
  ASSERT_EQ(before.str(), after.str());
230
}
231

232
// Fixed-index accesses repeated across two consecutive loops: A[1] is only
// loaded and A[0] is only stored, and both are expected to become scalars that
// are shared by the two loops.
// NOTE(review): A is declared with a single element yet index 1 is accessed;
// presumably harmless because the test only inspects the printed IR - confirm.
TEST(Registerizer, RegisterizerLoopInternalRepeated) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Each loop gets its own freshly built statements; nodes are not shared.
  StmtPtr first_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
           Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}));
  StmtPtr second_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),
           Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}));
  StmtPtr stmt = Block::make({first_loop, second_loop});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[1]);
   *   A[0] = x + (A[1]);
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[1];
   * int A_2 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A_2 = A_1 + x;
   *   A_2 = A_1 + x;
   * }
   * A[0] = A_2;
   */

  std::ostringstream printed;
  printed << *stmt;

  // A[1] may only appear in the initializer; A[0] only in the initializer and
  // the final write-back.
  const std::string expected =
      R"IR(
# CHECK: int A_1 = A[1];
# CHECK: int A_2 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   A_2 = A_1 + x;
# CHECK:   A_2 = A_1 + x;
# CHECK: }
# CHECK-NOT: A[1]
# CHECK: A[0] = A_2;
# CHECK-NOT: A[1]
# CHECK: })IR";

  torch::jit::testing::FileCheck().run(expected, printed.str());
}
302

303
// Same shape as RegisterizerLoopInternalRepeated, but the load is indexed by
// the loop variable. The test expects registerization to change nothing
// (presumably because A[x] may overlap the A[0] store).
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Each loop gets its own freshly built statements; nodes are not shared.
  StmtPtr first_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
           Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}));
  StmtPtr second_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),
           Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}));
  StmtPtr stmt = Block::make({first_loop, second_loop});
  stmt = IRSimplifier::simplify(stmt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[x]) + x;
   *   A[0] = (A[x]) + x;
   * }
   */

  // Snapshot the simplified IR so it can be compared after registerization.
  std::ostringstream before;
  before << *stmt;

  // Expected to leave the program as it is.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
347

348
// Same shape as RegisterizerLoopInternalRepeated, but the load is indexed by an
// unrelated variable y. The test expects registerization to change nothing
// (presumably because A[y] may overlap the A[0] store).
TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  VarHandle y("y", kInt);

  // Each loop gets its own freshly built statements; nodes are not shared.
  StmtPtr first_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
           Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}));
  StmtPtr second_loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),
           Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}));
  StmtPtr stmt = IRSimplifier::simplify(Block::make({first_loop, second_loop}));

  // The load index is y, not the loop variable. Operand order in the printed
  // sum may differ after simplification.
  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   * for (int x = 0; x < 10; x++) {
   *   A[0] = x + (A[y]);
   *   A[0] = x + (A[y]);
   * }
   */

  // Snapshot the simplified IR so it can be compared after registerization.
  std::ostringstream before;
  before << *stmt;

  // Expected to leave the program as it is.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
392

393
// Will registerize multiple accesses of different items of the same buffer.
394
TEST(Registerizer, RegisterizerMultiVar) {
395
  BufHandle a("A", {2}, kInt);
396
  VarHandle x("x", kInt);
397
  StmtPtr stmt = Block::make({
398
      Store::make(a, {0}, 0),
399
      Store::make(a, {1}, 0),
400
      For::make(
401
          x,
402
          0,
403
          10,
404
          Block::make(
405
              {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
406
               Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
407
  });
408

409
  /*
410
   * A[0] = 0;
411
   * A[1] = 0;
412
   * for (int x = 0; x < 10; x++) {
413
   *   A[0] = (A[0]) + x;
414
   *   A[1] = (A[1]) - x;
415
   * }
416
   */
417

418
  stmt = registerize(stmt);
419

420
  /*
421
   * int A_1 = 0;
422
   * int A_2 = 0;
423
   * for (int x = 0; x < 10; x++) {
424
   *   A_2 = x + A_2;
425
   *   A_1 = A_1 - x;
426
   * }
427
   * A[1] = A_2;
428
   * A[0] = A_1;
429
   */
430

431
  std::ostringstream oss;
432
  oss << *stmt;
433

434
  const std::string& verification_pattern =
435
      R"IR(
436
# CHECK: int A_1 = 0;
437
# CHECK: int A_2 = 0;
438
# CHECK: for (int x = 0; x < 10; x++)
439
# CHECK-NOT: A[
440
# CHECK:   A_1 =
441
# CHECK:   A_2 =
442
# CHECK: A[1] = A_2
443
# CHECK: A[0] = A_1;)IR";
444

445
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
446
}
447

448
// Will registerize the valid accesses while skipping invalid replacements.
449
TEST(Registerizer, RegisterizerVariableLoad) {
450
  BufHandle a("A", {1}, kInt);
451
  BufHandle b("B", {10}, kInt);
452
  VarHandle x("x", kInt);
453
  VarHandle x2("x", kInt);
454
  StmtPtr stmt = Block::make(
455
      {Store::make(a, {0}, 0),
456
       For::make(x, 0, 10, Store::make(b, {x}, x)),
457
       For::make(
458
           x2,
459
           0,
460
           10,
461
           Block::make({Store::make(
462
               a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});
463

464
  /*
465
   * A[0] = 0;
466
   * for (int x = 0; x < 10; x++) {
467
   *   B[x] = x;
468
   * }
469
   * for (int x_1 = 0; x_1 < 10; x_1++) {
470
   *   A[0] = (A[0]) + (B[x_1]);
471
   * }
472
   */
473

474
  stmt = registerize(stmt);
475

476
  /*
477
   * int A_1 = 0;
478
   * for (int x = 0; x < 10; x++) {
479
   *   B[x] = x;
480
   * }
481
   * for (int x_1 = 0; x_1 < 10; x_1++) {
482
   *   A_1 = A_1 + (B[x_1]);
483
   * }
484
   * A[0] = A_1;
485
   */
486

487
  std::ostringstream oss;
488
  oss << *stmt;
489

490
  const std::string& verification_pattern =
491
      R"IR(
492
# CHECK: int A_1 = 0;
493
# CHECK: for (int x = 0; x < 10; x++)
494
# CHECK:   B[x] = x
495
# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
496
# CHECK-NOT: A[
497
# CHECK:   A_1 =
498
# CHECK: A[0] = A_1;)IR";
499

500
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
501
}
502

503
// Can registerize variable accesses so long as the variable does not change.
504
TEST(Registerizer, RegisterizerSymbolicIndices) {
505
  VarHandle i("i", kInt);
506
  VarHandle N("N", kInt);
507
  BufHandle a("A", {N}, kInt);
508
  VarHandle x("x", kInt);
509
  StmtPtr stmt = Block::make(
510
      {Store::make(a, {i}, 0),
511
       For::make(
512
           x,
513
           0,
514
           10,
515
           Block::make(
516
               {Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});
517

518
  /*
519
   * A[i] = 0;
520
   * for (int x = 0; x < 10; x++) {
521
   *   A[i] = (A[i]) + x;
522
   * }
523
   */
524

525
  stmt = registerize(stmt);
526

527
  /*
528
   * int A_1 = 0;
529
   * for (int x = 0; x < 10; x++) {
530
   *   A_1 = x + A_1;
531
   * }
532
   * A[i] = A_1;
533
   */
534

535
  std::ostringstream oss;
536
  oss << *stmt;
537

538
  const std::string& verification_pattern =
539
      R"IR(
540
# CHECK: int A_1 = 0;
541
# CHECK: for (int x = 0; x < 10; x++)
542
# CHECK-NOT: A[
543
# CHECK:   A_1 =
544
# CHECK: A[i] = A_1;)IR";
545

546
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
547
}
548

549
// Can registerize accesses dependent on multiple loop vars.
550
TEST(Registerizer, RegisterizerMultiLoop) {
551
  BufHandle a("A", {1}, kInt);
552
  VarHandle x("x", kInt);
553
  VarHandle y("y", kInt);
554
  StmtPtr stmt = Block::make(
555
      {Store::make(a, {0}, 0),
556
       For::make(
557
           x,
558
           0,
559
           10,
560
           For::make(
561
               y,
562
               0,
563
               10,
564
               Block::make({Store::make(
565
                   a,
566
                   {0},
567
                   Mul::make(Add::make(Load::make(a, {0}), x), y))})))});
568

569
  /*
570
   * A[0] = 0;
571
   * for (int x = 0; x < 10; x++) {
572
   *   for (int y = 0; y < 10; y++) {
573
   *     A[0] = x * y + (A[0]) * y;
574
   *   }
575
   * }
576
   */
577

578
  stmt = registerize(stmt);
579

580
  /*
581
   * int A_1 = 0;
582
   * for (int x = 0; x < 10; x++) {
583
   *   for (int y = 0; y < 10; y++) {
584
   *     A_1 = x * y + y * A_1;
585
   *   }
586
   * }
587
   * A[0] = A_1;
588
   */
589

590
  std::ostringstream oss;
591
  oss << *stmt;
592

593
  const std::string& verification_pattern =
594
      R"IR(
595
# CHECK: int A_1 = 0;
596
# CHECK: for (int x = 0; x < 10; x++)
597
# CHECK:   for (int y = 0; y < 10; y++)
598
# CHECK-NOT: A[
599
# CHECK:     A_1 =
600
# CHECK: A[0] = A_1;)IR";
601

602
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
603
}
604

605
// Can registerize correctly if scalars already exist in the program.
606
TEST(Registerizer, RegisterizerRepeated) {
607
  BufHandle a("A", {2}, kInt);
608
  VarHandle x("x", kInt);
609
  StmtPtr stmt = Block::make({
610
      Store::make(a, {0}, 0),
611
      Store::make(a, {1}, 0),
612
      For::make(
613
          x,
614
          0,
615
          10,
616
          Block::make(
617
              {Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),
618
               Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),
619
  });
620

621
  // Registerize manually to make sure we only replace a single target.
622
  {
623
    registerizer::RegisterizerAnalysis analysis;
624
    stmt->accept(&analysis);
625
    auto candidates = analysis.getCandidates();
626
    ASSERT_EQ(candidates.size(), 2);
627

628
    candidates.pop_back();
629
    registerizer::RegisterizerReplacer replacer(candidates);
630
    stmt = stmt->accept_mutator(&replacer);
631
  }
632

633
  // Re-analyze and replace the second target.
634
  {
635
    registerizer::RegisterizerAnalysis analysis;
636
    stmt->accept(&analysis);
637
    auto candidates = analysis.getCandidates();
638
    ASSERT_EQ(candidates.size(), 1);
639

640
    registerizer::RegisterizerReplacer replacer(candidates);
641
    stmt = stmt->accept_mutator(&replacer);
642
  }
643

644
  std::ostringstream oss;
645
  oss << *stmt;
646

647
  const std::string& verification_pattern =
648
      R"IR(
649
# CHECK: int A_1 = 0;
650
# CHECK: int A_1_1 = 0;
651
# CHECK: for (int x = 0; x < 10; x++)
652
# CHECK-NOT: A[
653
# CHECK:   A_1 =
654
# CHECK:   A_1_1 =
655
# CHECK: A[1] = A_1_1;
656
# CHECK: A[0] = A_1;)IR";
657

658
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
659
}
660

661
// Can registerize the load of A.
662
TEST(Registerizer, RegisterizerNoLoads) {
663
  BufHandle a("A", {1}, kInt);
664
  VarHandle x("x", kInt);
665
  StmtPtr stmt = Block::make(
666
      {Store::make(a, {0}, 0),
667
       For::make(
668
           x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});
669

670
  /*
671
   * A[0] = 0;
672
   * for (int x = 0; x < 10; x++) {
673
   *   A[0] = x + 1;
674
   * }
675
   */
676

677
  stmt = registerize(stmt);
678

679
  /*
680
   * int A_1 = 0;
681
   * for (int x = 0; x < 10; x++) {
682
   *   A_1 = x + 1;
683
   * }
684
   * A[0] = A_1;
685
   */
686

687
  std::ostringstream oss;
688
  oss << *stmt;
689

690
  const std::string& verification_pattern =
691
      R"IR(
692
# CHECK: int A_1 = 0;
693
# CHECK: for (int x = 0; x < 10; x++)
694
# CHECK-NOT: A[
695
# CHECK:   A_1 =
696
# CHECK: A[0] = A_1;)IR";
697

698
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
699
}
700

701
// Can registerize the load of A but not the store of B.
702
TEST(Registerizer, RegisterizerNoRepeatedStores) {
703
  BufHandle a("A", {1}, kInt);
704
  BufHandle b("B", {10}, kInt);
705
  VarHandle x("x", kInt);
706
  StmtPtr stmt = Block::make(
707
      {Store::make(a, {0}, 0),
708
       For::make(
709
           x,
710
           0,
711
           10,
712
           Block::make(
713
               {Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});
714

715
  /*
716
   * A[0] = 0;
717
   * for (int x = 0; x < 10; x++) {
718
   *   B[x] = (A[0]) + x;
719
   * }
720
   */
721

722
  stmt = registerize(stmt);
723

724
  // TODO: its unnecessary to reorder the initializer of A[0], but it's not
725
  // actually worse so lets not worry for now.
726

727
  /*
728
   * int A_1 = 0;
729
   * for (int x = 0; x < 10; x++) {
730
   *   B[x] = x + A_1;
731
   * }
732
   * A[0] = A_1;
733
   */
734

735
  std::ostringstream oss;
736
  oss << *stmt;
737

738
  const std::string& verification_pattern =
739
      R"IR(
740
# CHECK: int A_1 = 0;
741
# CHECK: for (int x = 0; x < 10; x++)
742
# CHECK-NOT: A_
743
# CHECK:   B[x] =
744
# CHECK: A[0] = A_1;)IR";
745

746
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
747
}
748

749
// Won't registerize if there are multiple accesses which may overlap.
750
TEST(Registerizer, RegisterizerMultiVarOverlap) {
751
  BufHandle a("A", {2}, kInt);
752
  VarHandle x("x", kInt);
753
  StmtPtr stmt = Block::make({
754
      Store::make(a, {0}, 0),
755
      Store::make(a, {1}, 0),
756
      For::make(
757
          x,
758
          0,
759
          10,
760
          Block::make(
761
              {Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),
762
               Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),
763
  });
764
  stmt = IRSimplifier::simplify(stmt);
765

766
  std::ostringstream before;
767
  before << *stmt;
768

769
  // No change.
770
  stmt = registerize(stmt);
771

772
  std::ostringstream after;
773
  after << *stmt;
774

775
  ASSERT_EQ(before.str(), after.str());
776
}
777

778
// Registerization in the presence of Allocate/Free: the load of C[0] is hoisted
// above the Allocate, and the scalar for B is written back before B is freed.
TEST(Registerizer, RegisterizerAllocs) {
  BufHandle a("A", {2}, kInt);
  BufHandle c("C", {1}, kInt);
  VarHandle x("x", kInt);

  // The size of B is itself a load from C.
  BufHandle b("B", {Load::make(c, {0})}, kInt);

  StmtPtr alloc_b = Allocate::make(b);
  StmtPtr init_a = Store::make(a, {0}, Load::make(c, {0}));
  StmtPtr init_b = Store::make(b, {0}, 0);
  StmtPtr loop = For::make(
      x,
      0,
      10,
      Block::make(
          {Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),
           Store::make(a, {0}, Load::make(c, {0}))}));
  StmtPtr free_b = Free::make(b);
  StmtPtr stmt = Block::make({alloc_b, init_a, init_b, loop, free_b});

  /*
   * Allocate(B, int, {C[0]});
   * A[0] = C[0];
   * B[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (B[0]) + x;
   *   A[0] = C[0];
   * }
   * Free(B);
   */

  stmt = registerize(stmt);

  /*
   * int C_1 = C[0];
   * Allocate(B, int, {C_1});
   * int A_1 = C_1;
   * int B_1 = 0;
   * for (int x = 0; x < 10; x++) {
   *   B_1 = B_1 + x;
   *   A_1 = C_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   * Free(B);
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected =
      R"IR(
# CHECK: int C_1 = C[0];
# CHECK: Allocate(B
# CHECK: int A_1 = C_1;
# CHECK: int B_1 = 0;
# CHECK: for (int x = 0; x < 10; x++)
# CHECK:   B_1 =
# CHECK:   A_1 = C_
# CHECK: B[0] = B_1;
# CHECK: A[0] = A_1;
# CHECK: Free(B)IR";

  torch::jit::testing::FileCheck().run(expected, printed.str());
}
843

844
// When there is no store before the loop, the scalar is initialized from the
// buffer element itself.
TEST(Registerizer, RegisterizerNoInitializer) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  StmtPtr accumulate = Store::make(a, {0}, Add::make(Load::make(a, {0}), x));
  StmtPtr loop = For::make(x, 0, 10, Block::make({accumulate}));
  StmtPtr stmt = Block::make({loop});

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * for (int x = 0; x < 10; x++) {
   *   A_1 = x + A_1;
   * }
   * A[0] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: A[
# CHECK:   A_1 =
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected, printed.str());
}
882

883
// A single read-modify-write indexed by the loop variable, with no initializer
// before the loop: the test expects registerization to change nothing.
TEST(Registerizer, RegisterizerNoInitializerLoopVar) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  StmtPtr update = Store::make(a, {x}, Add::make(Load::make(a, {x}), x));
  StmtPtr loop = For::make(x, 0, 10, Block::make({update}));
  StmtPtr stmt = Block::make({loop});
  stmt = IRSimplifier::simplify(stmt);

  /*
   * for (int x = 0; x < 10; x++) {
   *   A[x] = (A[x]) + x;
   * }
   */

  // Snapshot the simplified IR so it can be compared after registerization.
  std::ostringstream before;
  before << *stmt;

  // Expected to leave the program as it is.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}
910

911
// Two buffers that feed each other inside a loop (B from A, then A from B) are
// both registerized, each with an initializer before and a write-back after
// the loop.
TEST(Registerizer, RegisterizerLoadThenStore) {
  BufHandle a("A", {1}, kInt);
  BufHandle b("B", {1}, kInt);
  VarHandle x("x", kInt);

  StmtPtr write_b = Store::make(b, {0}, Add::make(Load::make(a, {0}), x));
  StmtPtr copy_back = Store::make(a, {0}, Load::make(b, {0}));
  StmtPtr loop = For::make(x, 0, 10, Block::make({write_b, copy_back}));
  StmtPtr stmt = Block::make({loop});

  /*
   * for (int x = 0; x < 10; x++) {
   *   B[0] = (A[0]) + x;
   *   A[0] = B[0];
   * }
   */

  stmt = registerize(stmt);

  /*
   * int A_1 = A[0];
   * int B_1 = B[0];
   * for (int x = 0; x < 10; x++) {
   *   B_1 = x + A_1;
   *   A_1 = B_1;
   * }
   * B[0] = B_1;
   * A[0] = A_1;
   */

  std::ostringstream printed;
  printed << *stmt;

  const std::string expected =
      R"IR(
# CHECK: int A_1 = A[0];
# CHECK: int B_1 = B[0];
# CHECK: for (int x = 0; x < 10; x++)
# CHECK-NOT: B[
# CHECK:   B_1 =
# CHECK-NOT: A[
# CHECK:   A_1 = B_
# CHECK: B[0] = B_
# CHECK: A[0] = A_1;)IR";

  torch::jit::testing::FileCheck().run(expected, printed.str());
}
960

961
// Registerizing a program that still contains a loop bound to a GPU block
// index is rejected with an error.
TEST(Registerizer, RegisterizerParallelized) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);

  // Mark the loop as mapped to GPU block index 0.
  LoopOptions gpu_opts;
  gpu_opts.set_gpu_block_index(0);

  StmtPtr init = Store::make(a, {0}, 0);
  StmtPtr loop = For::make(
      x,
      0,
      10,
      Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}),
      gpu_opts);
  StmtPtr stmt = Block::make({init, loop});

  /*
   * A[0] = 0;
   * for (int x = 0; x < 10; x++) {
   *   A[0] = (A[0]) + x;
   * }
   */

  ASSERT_THROWS_WITH(
      registerize(stmt),
      "Registerization must occur after parallelism flattening");
}
986

987
// Should be able to registerize this since the scalar would exist before the
988
// branch.
989
TEST(Registerizer, RegisterizerConditionAfter) {
990
  BufHandle a("A", {5}, kInt);
991
  BufHandle b("B", {5}, kInt);
992
  BufHandle c("C", {5}, kInt);
993
  VarHandle x("x", kInt);
994

995
  StmtPtr stmt = Block::make(
996
      {Store::make(a, {x}, Load::make(b, {x})),
997
       Store::make(c, {x}, Load::make(a, {x})),
998
       Cond::make(
999
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1000
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1001
           nullptr)});
1002

1003
  /*
1004
   * A[x] = B[x];
1005
   * C[x] = A[x];
1006
   * if (x<5 ? 1 : 0) {
1007
   *   A[x] = (A[x]) + 1;
1008
   * }
1009
   */
1010

1011
  stmt = registerize(stmt);
1012

1013
  /*
1014
   * int A_1 = B[x];
1015
   * C[x] = A_1;
1016
   * if (x<5 ? 1 : 0) {
1017
   *   A_1 = A_1 + 1;
1018
   * }
1019
   * A[x] = A_1;
1020
   */
1021

1022
  std::ostringstream oss;
1023
  oss << *stmt;
1024

1025
  const std::string& verification_pattern =
1026
      R"IR(
1027
# CHECK: int A_1 = B[x];
1028
# CHECK: C[x] = A_1;
1029
# CHECK: if (
1030
# CHECK:   A_1 = A_1 + 1;
1031
# CHECK: A[x] = A_1;)IR";
1032

1033
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1034
}
1035

1036
// Should be able to registerize this since the scalar exists in the same form
1037
// after the branch and there is no overlap.
1038
TEST(Registerizer, RegisterizerConditionBefore) {
1039
  BufHandle a("A", {5}, kInt);
1040
  BufHandle b("B", {5}, kInt);
1041
  BufHandle c("C", {5}, kInt);
1042
  VarHandle x("x", kInt);
1043

1044
  StmtPtr stmt = Block::make(
1045
      {Cond::make(
1046
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1047
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1048
           nullptr),
1049
       Store::make(a, {x}, Load::make(b, {x})),
1050
       Store::make(c, {x}, Load::make(a, {x}))});
1051

1052
  /*
1053
   * if (x<5 ? 1 : 0) {
1054
   *   A[x] = (A[x]) + 1;
1055
   * }
1056
   * A[x] = B[x];
1057
   * C[x] = A[x];
1058
   */
1059

1060
  stmt = registerize(stmt);
1061

1062
  /*
1063
   * int A_ 1 = A[x];
1064
   * if (x<5 ? 1 : 0) {
1065
   *   A_1 = A_1 + 1;
1066
   * }
1067
   * A_1 = B[x];
1068
   * C[x] = A_1;
1069
   * A[x] = A_1;
1070
   */
1071

1072
  std::ostringstream oss;
1073
  oss << *stmt;
1074

1075
  const std::string& verification_pattern =
1076
      R"IR(
1077
# CHECK: int A_1 = A[x];
1078
# CHECK: if (
1079
# CHECK:   A_1 = A_1 + 1;
1080
# CHECK: }
1081
# CHECK: A_1 = B[x];
1082
# CHECK: C[x] = A_1;
1083
# CHECK: A[x] = A_1;)IR";
1084

1085
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1086
}
1087

1088
// Should be able to registerize this as the combination of the two above rules.
1089
TEST(Registerizer, RegisterizerConditionInside) {
1090
  BufHandle a("A", {5}, kInt);
1091
  BufHandle b("B", {5}, kInt);
1092
  BufHandle c("C", {5}, kInt);
1093
  VarHandle x("x", kInt);
1094

1095
  StmtPtr stmt = Block::make(
1096
      {Store::make(a, {x}, Load::make(b, {x})),
1097
       Store::make(c, {x}, Load::make(a, {x})),
1098
       Cond::make(
1099
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1100
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1101
           nullptr),
1102
       Store::make(b, {x}, Load::make(a, {x})),
1103
       Store::make(a, {x}, Load::make(c, {x}))});
1104

1105
  /*
1106
   * A[x] = B[x];
1107
   * C[x] = A[x];
1108
   * if (x<5 ? 1 : 0) {
1109
   *   A[x] = (A[x]) + 1;
1110
   * }
1111
   * B[x] = A[x];
1112
   * A[x] = C[x];
1113
   */
1114

1115
  stmt = registerize(stmt);
1116

1117
  /*
1118
   * int A_1 = B[x];
1119
   * C[x] = A_1;
1120
   * if (x<5 ? 1 : 0) {
1121
   *   A_1 = A_1 + 1;
1122
   * }
1123
   * B[x] = A_1;
1124
   * A_1 = C[x];
1125
   * A[x] = A_1;
1126
   */
1127

1128
  std::ostringstream oss;
1129
  oss << *stmt;
1130

1131
  const std::string& verification_pattern =
1132
      R"IR(
1133
# CHECK: int A_1 = B[x];
1134
# CHECK: C[x] = A_1;
1135
# CHECK: if (
1136
# CHECK:   A_1 = A_1 + 1;
1137
# CHECK: }
1138
# CHECK: B[x] = A_1;
1139
# CHECK: A_1 = C[x];
1140
# CHECK: A[x] = A_1;)IR";
1141

1142
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1143
}
1144

1145
// An example where an access is cut by an overlapping access inside a
1146
// condition, and both sides are large enough to be registerized but cannot be
1147
// because there is no safe place to put the initializer or finalizer.
1148
TEST(Registerizer, RegisterizerConditionInsideOverlap1) {
1149
  BufHandle a("A", {5}, kInt);
1150
  BufHandle b("B", {5}, kInt);
1151
  BufHandle c("C", {5}, kInt);
1152
  VarHandle x("x", kInt);
1153
  VarHandle y("y", kInt);
1154

1155
  StmtPtr stmt = Block::make(
1156
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1157
      {Store::make(a, {x}, Load::make(b, {x})),
1158
       Store::make(c, {x}, Load::make(a, {x})),
1159
       Cond::make(
1160
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1161
           Block::make({
1162
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1163
               Store::make(a, {0}, 3),
1164
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1165
           }),
1166
           nullptr),
1167
       Store::make(b, {x}, Load::make(a, {x})),
1168
       Store::make(a, {x}, Load::make(c, {x}))});
1169

1170
  /*
1171
   * A[x] = B[x];
1172
   * C[x] = A[x];
1173
   * if (x<5 ? 1 : 0) {
1174
   *   A[x] = (A[x]) + 1;
1175
   *   A[0] = 3;
1176
   *   A[x] = (A[x]) + 1;
1177
   * }
1178
   * B[x] = A[x];
1179
   * A[x] = C[x];
1180
   */
1181

1182
  // The A[0] store overlaps, A[x] cutting the region that can be registerized
1183
  // into two groups.
1184
  // Each group has 2 loads and 2 stores however, so we could registerize it,
1185
  // but the first group would need to be finalized inside the condition block,
1186
  // the second would need to be initialized inside the condition block. There's
1187
  // no safe place to put these that's visible to the other uses in the group
1188
  // and so neither registerization is possible.
1189

1190
  std::ostringstream before;
1191
  before << *stmt;
1192

1193
  // No change.
1194
  stmt = registerize(stmt);
1195

1196
  std::ostringstream after;
1197
  after << *stmt;
1198

1199
  ASSERT_EQ(before.str(), after.str());
1200
}
1201

1202
// Same as the above, but the access group before the condition (and after the
1203
// condition) are large enough to be registerized without needing the access
1204
// from the loop. Registerization occurs but does not include any accesses in
1205
// the condition, and the first group must be finalized before the Cond, the
1206
// second initialized after it.
1207
TEST(Registerizer, RegisterizerConditionInsideOverlap2) {
1208
  BufHandle a("A", {5}, kInt);
1209
  BufHandle b("B", {5}, kInt);
1210
  BufHandle c("C", {5}, kInt);
1211
  VarHandle x("x", kInt);
1212
  VarHandle y("y", kInt);
1213

1214
  StmtPtr stmt = Block::make(
1215
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
1216
      {Store::make(a, {x}, Load::make(b, {x})),
1217
       Store::make(a, {x}, Load::make(b, {x + 1})),
1218
       Store::make(c, {x}, Load::make(a, {x})),
1219
       Cond::make(
1220
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1221
           Block::make({
1222
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1223
               Store::make(a, {0}, 3),
1224
               Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1225
           }),
1226
           nullptr),
1227
       Store::make(b, {x}, Load::make(a, {x})),
1228
       Store::make(b, {x + 1}, Load::make(a, {x})),
1229
       Store::make(a, {x}, Load::make(c, {x}))});
1230

1231
  /*
1232
   * A[x] = B[x];
1233
   * A[x] = B[x + 1];
1234
   * C[x] = A[x];
1235
   * if (x<5 ? 1 : 0) {
1236
   *   A[x] = (A[x]) + 1;
1237
   *   A[0] = 3;
1238
   *   A[x] = (A[x]) + 1;
1239
   * }
1240
   * B[x] = A[x];
1241
   * B[x + 1] = A[x];
1242
   * A[x] = C[x];
1243
   */
1244

1245
  stmt = registerize(stmt);
1246

1247
  /*
1248
   * int A_1 = B[x];              // A_1 initializer
1249
   * A_1 = B[x + 1];              //
1250
   * C[x] = A_1;                  //
1251
   * A[x] = A_1;                  // A_1 finalizer
1252
   * if (x<5 ? 1 : 0) {
1253
   *   A[x] = (A[x]) + 1;
1254
   *   A[0] = 3;
1255
   *   A[x] = (A[x]) + 1;
1256
   * }
1257
   * int A_2 = A[x];              // A_2 initialier
1258
   * B[x] = A_2;                  //
1259
   * B[x + 1] = A_2;              //
1260
   * A_2 = C[x];                  //
1261
   * A[x] = A_2;                  // A_2 finalizer
1262
   */
1263

1264
  std::ostringstream oss;
1265
  oss << *stmt;
1266

1267
  const std::string& verification_pattern =
1268
      R"IR(
1269
# CHECK: int A_1 = B[x];
1270
# CHECK: A_1 = B[x + 1];
1271
# CHECK: C[x] = A_1;
1272
# CHECK: A[x] = A_1;
1273
# CHECK: if (
1274
# CHECK-NOT:   A_1 = A_1 + 1;
1275
# CHECK:   A[x] = (A[x]
1276
# CHECK:   A[0] =
1277
# CHECK:   A[x] = (A[x]
1278
# CHECK: }
1279
# CHECK: int A_2 = A[x];
1280
# CHECK: B[x] = A_2;
1281
# CHECK: B[x + 1] = A_2;
1282
# CHECK: A_2 = C[x];
1283
# CHECK: A[x] = A_2;)IR";
1284

1285
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1286
}
1287

1288
// When accesses are within conditional blocks they are not visible to the wider
1289
// program, because we don't know if the branch would be taken and if it isn't
1290
// the accesses in it don't need to be valid (think size checks on the index).
1291
// In this case the accesses cannot be registerized.
1292
TEST(Registerizer, RegisterizerConditionHidden) {
1293
  BufHandle a("A", {5}, kInt);
1294
  BufHandle b("B", {5}, kInt);
1295
  BufHandle c("C", {5}, kInt);
1296
  VarHandle x("x", kInt);
1297

1298
  StmtPtr stmt = Block::make(
1299
      {Cond::make(
1300
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1301
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1302
           nullptr),
1303
       Cond::make(
1304
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
1305
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1306
           nullptr)});
1307

1308
  /*
1309
   * if (x<5 ? 1 : 0) {
1310
   *   A[x] = (A[x]) + 1;
1311
   * }
1312
   * if (x>5 ? 1 : 0) {
1313
   *   A[x] = (A[x]) + 1;
1314
   * }
1315
   */
1316

1317
  std::ostringstream before;
1318
  before << *stmt;
1319

1320
  // No change.
1321
  stmt = registerize(stmt);
1322

1323
  std::ostringstream after;
1324
  after << *stmt;
1325

1326
  ASSERT_EQ(before.str(), after.str());
1327
}
1328

1329
// But... if the same access is found in a non conditional scope, that means
1330
// that that access is valid in the higher scope (or at least if its not it's
1331
// the user's fault). It "unhides" the conditional accesses, allowing
1332
// registerization to occur.
1333
TEST(Registerizer, RegisterizerConditionUnhidden) {
1334
  BufHandle a("A", {5}, kInt);
1335
  BufHandle b("B", {5}, kInt);
1336
  BufHandle c("C", {5}, kInt);
1337
  VarHandle x("x", kInt);
1338

1339
  StmtPtr stmt = Block::make(
1340
      {Cond::make(
1341
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1342
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1343
           nullptr),
1344
       Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1345
       Cond::make(
1346
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
1347
           Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1348
           nullptr)});
1349

1350
  /*
1351
   * if (x<5 ? 1 : 0) {
1352
   *   A[x] = (A[x]) + 1;
1353
   * }
1354
   * A[x] = (A[x]) + 1;            <-- this is doing the unhiding.
1355
   * if (x>5 ? 1 : 0) {
1356
   *   A[x] = (A[x]) + 1;
1357
   * }
1358
   */
1359

1360
  stmt = registerize(stmt);
1361

1362
  /*
1363
   * int A_1 = A[x];
1364
   * if (x<5 ? 1 : 0) {
1365
   *   A_1 = A_1 + 1;
1366
   * }
1367
   * A_1 = A_1 + 1;
1368
   * if (x>5 ? 1 : 0) {
1369
   *   A_1 = A_1 + 1;
1370
   * }
1371
   * A[x] = A_1;
1372
   */
1373

1374
  std::ostringstream oss;
1375
  oss << *stmt;
1376

1377
  const std::string& verification_pattern =
1378
      R"IR(
1379
# CHECK: int A_1 = A[x];
1380
# CHECK: if (x<5
1381
# CHECK:   A_1 = A_1 + 1;
1382
# CHECK: }
1383
# CHECK: A_1 = A_1 + 1;
1384
# CHECK: if (x>5
1385
# CHECK:   A_1 = A_1 + 1;
1386
# CHECK: }
1387
# CHECK: A[x] = A_1;)IR";
1388

1389
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1390
}
1391

1392
// Can registerize a load that occurs in the condition of a Cond.
1393
TEST(Registerizer, RegisterizerCondCondition) {
1394
  BufHandle a("A", {5}, kInt);
1395
  BufHandle b("B", {5}, kInt);
1396
  BufHandle c("C", {5}, kInt);
1397
  VarHandle x("x", kInt);
1398

1399
  StmtPtr stmt = Block::make(
1400
      {Store::make(a, {x}, Load::make(b, {x})),
1401
       Store::make(c, {x}, Load::make(a, {x})),
1402
       Cond::make(
1403
           CompareSelect::make(
1404
               Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1405
           Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
1406
           nullptr)});
1407

1408
  /*
1409
   * A[x] = B[x];
1410
   * C[x] = A[x];
1411
   * if ((A[x])<5 ? 1 : 0) {
1412
   *   C[x] = (C[x]) + 1;
1413
   * }
1414
   */
1415

1416
  stmt = registerize(stmt);
1417

1418
  /*
1419
   * int A_1 = B[x];
1420
   * int C_1 = A_1;
1421
   * if (A_1<5 ? 1 : 0) {
1422
   *   C_1 = C_1 + 1;
1423
   * }
1424
   * C[x] = C_1;
1425
   */
1426

1427
  std::ostringstream oss;
1428
  oss << *stmt;
1429

1430
  const std::string& verification_pattern =
1431
      R"IR(
1432
# CHECK: int A_1 = B[x];
1433
# CHECK: int C_1 = A_1;
1434
# CHECK: if (A_1<5
1435
# CHECK:   C_1 = C_1 + 1;
1436
# CHECK: C[x] = C_1;)IR";
1437

1438
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1439
}
1440

1441
// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1442
// and so we can registerize internal usages.
1443
TEST(Registerizer, RegisterizerCondConditionUnhidden) {
1444
  BufHandle a("A", {5}, kInt);
1445
  BufHandle b("B", {5}, kInt);
1446
  BufHandle c("C", {5}, kInt);
1447
  VarHandle x("x", kInt);
1448

1449
  StmtPtr stmt = Block::make({Cond::make(
1450
      CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1451
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),
1452
      Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});
1453

1454
  /*
1455
   * if ((A[x])<5 ? 1 : 0) {
1456
   *   A[x] = (A[x]) + 1;
1457
   * } else {
1458
   *   A[x] = (A[x]) + 10;
1459
   * }
1460
   */
1461

1462
  stmt = registerize(stmt);
1463

1464
  /*
1465
   * int A_1 = A[x];
1466
   * if (A_1<5 ? 1 : 0) {
1467
   *   A_1 = A_1 + 1;
1468
   * } else {
1469
   *   A_1 = A_1 + 10;
1470
   * }
1471
   * A[x] = A_1;
1472
   */
1473

1474
  std::ostringstream oss;
1475
  oss << *stmt;
1476

1477
  const std::string& verification_pattern =
1478
      R"IR(
1479
# CHECK: int A_1 = A[x];
1480
# CHECK: if (A_1<5
1481
# CHECK:   A_1 = A_1 + 1;
1482
# CHECK: } else {
1483
# CHECK:   A_1 = A_1 + 10;
1484
# CHECK: }
1485
# CHECK: A[x] = A_1;)IR";
1486

1487
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1488
}
1489

1490
// Conditional hiding also works for IfThenElse exprs.
1491
TEST(Registerizer, RegisterizerIfThenElseHidden) {
1492
  BufHandle a("A", {5}, kInt);
1493
  BufHandle b("B", {5}, kInt);
1494
  BufHandle c("C", {5}, kInt);
1495
  VarHandle x("x", kInt);
1496
  VarHandle y("y", kInt);
1497

1498
  StmtPtr stmt = Block::make(
1499
      {Store::make(
1500
           b,
1501
           {y},
1502
           IfThenElse::make(
1503
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1504
               Add::make(Load::make(a, {x}), 1),
1505
               Add::make(Load::make(a, {x + 1}), 2))),
1506
       Store::make(
1507
           b,
1508
           {y + 1},
1509
           IfThenElse::make(
1510
               CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1511
               Add::make(Load::make(a, {x}), 1),
1512
               Add::make(Load::make(a, {x + 1}), 2)))});
1513

1514
  /*
1515
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1516
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1517
   */
1518

1519
  std::ostringstream before;
1520
  before << *stmt;
1521

1522
  // No change.
1523
  stmt = registerize(stmt);
1524

1525
  std::ostringstream after;
1526
  after << *stmt;
1527

1528
  ASSERT_EQ(before.str(), after.str());
1529
}
1530

1531
// Conditional unhiding also works for IfThenElse exprs.
1532
TEST(Registerizer, RegisterizerIfThenElseUnhidden) {
1533
  BufHandle a("A", {5}, kInt);
1534
  BufHandle b("B", {5}, kInt);
1535
  BufHandle c("C", {5}, kInt);
1536
  VarHandle x("x", kInt);
1537
  VarHandle y("y", kInt);
1538

1539
  StmtPtr stmt = Block::make({
1540
      Store::make(a, {x}, 0),
1541
      Store::make(
1542
          b,
1543
          {y},
1544
          IfThenElse::make(
1545
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1546
              Add::make(Load::make(a, {x}), 1),
1547
              Add::make(Load::make(a, {x + 1}), 2))),
1548
      Store::make(
1549
          b,
1550
          {y + 1},
1551
          IfThenElse::make(
1552
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1553
              Add::make(Load::make(a, {x}), 1),
1554
              Add::make(Load::make(a, {x + 1}), 2))),
1555
  });
1556

1557
  /*
1558
   * A[x] = 0;
1559
   * B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1560
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1561
   */
1562

1563
  stmt = registerize(stmt);
1564

1565
  /*
1566
   * int A_1 = 0;
1567
   * B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1568
   * B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1569
   * A[x] = A_1;
1570
   */
1571

1572
  std::ostringstream oss;
1573
  oss << *stmt;
1574

1575
  const std::string& verification_pattern =
1576
      R"IR(
1577
# CHECK: int A_1 = 0;
1578
# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1579
# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1580
# CHECK: A[x] = A_1;)IR";
1581

1582
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1583
}
1584

1585
// Nested IfThenElse exprs can't promote to higher level scopes.
1586
TEST(Registerizer, RegisterizerIfThenElseNested) {
1587
  BufHandle a("A", {5}, kInt);
1588
  BufHandle b("B", {5}, kInt);
1589
  BufHandle c("C", {5}, kInt);
1590
  BufHandle d("D", {5}, kInt);
1591
  VarHandle x("x", kInt);
1592

1593
  StmtPtr stmt = Block::make({Store::make(
1594
      a,
1595
      {x},
1596
      IfThenElse::make(
1597
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1598
          IfThenElse::make(
1599
              CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
1600
              Load::make(d, {x}),
1601
              Load::make(b, {x})),
1602
          IfThenElse::make(
1603
              CompareSelect::make(x, 5, CompareSelectOperation::kEQ),
1604
              Load::make(c, {x}),
1605
              Load::make(d, {x}))))});
1606

1607
  /*
1608
   * A[x] = IfThenElse(x<3 ? 1 : 0,
1609
   *          IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
1610
   *            IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
1611
   */
1612

1613
  std::ostringstream before;
1614
  before << *stmt;
1615

1616
  // No change.
1617
  stmt = registerize(stmt);
1618

1619
  std::ostringstream after;
1620
  after << *stmt;
1621

1622
  ASSERT_EQ(before.str(), after.str());
1623
}
1624

1625
// Cannot registerize an access completely contained within an IfThenElse
1626
// branch, since it is not a Stmt and cannot hold variable definitions. We need
1627
// to check that we don't promote the initializer/finalizer to the enclosing
1628
// Block.
1629
TEST(Registerizer, RegisterizerIfThenElseInternal) {
1630
  // Making these floats so they don't get simplified to a single access.
1631
  BufHandle a("A", {5}, kFloat);
1632
  BufHandle b("B", {5}, kFloat);
1633
  VarHandle x("x", kInt);
1634

1635
  StmtPtr stmt = Block::make({Store::make(
1636
      a,
1637
      {x},
1638
      IfThenElse::make(
1639
          CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1640
          Add::make(Load::make(b, {x}), Load::make(b, {x})),
1641
          Load::make(b, {x})))});
1642

1643
  /*
1644
   * A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
1645
   */
1646

1647
  std::ostringstream before;
1648
  before << *stmt;
1649

1650
  // No change.
1651
  stmt = registerize(stmt);
1652

1653
  std::ostringstream after;
1654
  after << *stmt;
1655

1656
  ASSERT_EQ(before.str(), after.str());
1657

1658
  // If this was a Cond instead of an IfThenElse then we could registerize the
1659
  // two accesses to B[x] in the True branch.
1660

1661
  // Actually lets verify that.
1662

1663
  stmt = Block::make({Cond::make(
1664
      CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1665
      Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),
1666
      Store::make(a, {x}, Load::make(b, {x})))});
1667

1668
  /*
1669
   * if (x<3 ? 1 : 0) {
1670
   *   A[x] = (B[x]) + (B[x]);
1671
   * } else {
1672
   *   A[x] = B[x];
1673
   * }
1674
   */
1675

1676
  stmt = registerize(stmt);
1677

1678
  /*
1679
   * if (x<3 ? 1 : 0) {
1680
   *   float B_1 = B[x];
1681
   *   A[x] = B_1 + B_1;
1682
   * } else {
1683
   *   A[x] = B[x];
1684
   * }
1685
   */
1686

1687
  std::ostringstream oss;
1688
  oss << *stmt;
1689

1690
  const std::string& verification_pattern =
1691
      R"IR(
1692
# CHECK-NOT: int
1693
# CHECK-NOT: float
1694
# CHECK: if (x<3
1695
# CHECK:   float B_1 =
1696
# CHECK:   A[x] = B_1 + B_1
1697
# CHECK: } else {
1698
# CHECK:   A[x] = B[x]
1699
# CHECK: }
1700
# CHECK-NOT: A[x]
1701
# CHECK-NOT: B[x])IR";
1702

1703
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1704
}
1705

1706
// Can registerize a load that occurs in the condition of an IfThenElse;
1707
TEST(Registerizer, RegisterizerIfThenElseCondition) {
1708
  BufHandle a("A", {5}, kInt);
1709
  BufHandle b("B", {5}, kInt);
1710
  BufHandle c("C", {5}, kInt);
1711
  VarHandle x("x", kInt);
1712

1713
  StmtPtr stmt = Block::make(
1714
      {Store::make(a, {x}, Load::make(a, {x})),
1715
       Store::make(
1716
           a,
1717
           {x},
1718
           IfThenElse::make(
1719
               CompareSelect::make(
1720
                   Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1721
               Load::make(b, {0}),
1722
               Load::make(c, {0})))});
1723

1724
  /*
1725
   * A[x] = A[x];       <---- just here so there are enough accesses to combine.
1726
   * A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
1727
   */
1728

1729
  stmt = registerize(stmt);
1730

1731
  /*
1732
   * int A_1 = A[x];
1733
   * A_1 = A_1;
1734
   * A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1735
   * A[x] = A_1;
1736
   */
1737

1738
  std::ostringstream oss;
1739
  oss << *stmt;
1740

1741
  const std::string& verification_pattern =
1742
      R"IR(
1743
# CHECK: int A_1 = A[x];
1744
# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1745
# CHECK: A[x] = A_1;)IR";
1746

1747
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1748
}
1749

1750
// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1751
// and so we can registerize internal usages.
1752
TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {
1753
  BufHandle a("A", {5}, kInt);
1754
  BufHandle b("B", {5}, kInt);
1755
  BufHandle c("C", {5}, kInt);
1756
  VarHandle x("x", kInt);
1757

1758
  StmtPtr stmt = Block::make({Store::make(
1759
      b,
1760
      {x},
1761
      IfThenElse::make(
1762
          CompareSelect::make(
1763
              Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1764
          Add::make(Load::make(a, {x}), 1),
1765
          Add::make(Load::make(a, {x}), 10)))});
1766

1767
  /*
1768
   * B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
1769
   */
1770

1771
  stmt = registerize(stmt);
1772

1773
  /*
1774
   * int A_1 = A[x];
1775
   * B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
1776
   */
1777

1778
  std::ostringstream oss;
1779
  oss << *stmt;
1780

1781
  const std::string& verification_pattern =
1782
      R"IR(
1783
# CHECK: int A_1 = A[x];
1784
# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";
1785

1786
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1787
}
1788

1789
// Cannot promote accesses internal to IfThenElse branches even if the enclosing
1790
// scope if conditional.
1791
TEST(Registerizer, RegisterizerConditionBranchOnly) {
1792
  BufHandle a("A", {5}, kInt);
1793
  VarHandle x("x", kInt);
1794
  StmtPtr stmt = Block::make({For::make(
1795
      x,
1796
      0,
1797
      10,
1798
      Block::make({
1799
          Cond::make(
1800
              CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1801
              Store::make(
1802
                  a,
1803
                  {x},
1804
                  IfThenElse::make(
1805
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1806
                      Add::make(Load::make(a, {x}), x),
1807
                      Add::make(Load::make(a, {x - 5}), x))),
1808
              Store::make(
1809
                  a,
1810
                  {x - 5},
1811
                  IfThenElse::make(
1812
                      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
1813
                      Add::make(Load::make(a, {x}), x),
1814
                      Add::make(Load::make(a, {x - 5}), x)))),
1815
      }))});
1816
  stmt = IRSimplifier::simplify(stmt);
1817

1818
  std::ostringstream before;
1819
  before << *stmt;
1820

1821
  /* for (int x = 0; x < 10; x++) {
1822
   *   if (x<5 ? 1 : 0) {
1823
   *     A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1824
   *   } else {
1825
   *     A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1826
   *   }
1827
   * }
1828
   */
1829

1830
  // No change.
1831
  stmt = registerize(stmt);
1832

1833
  std::ostringstream after;
1834
  after << *stmt;
1835

1836
  ASSERT_EQ(before.str(), after.str());
1837
}
1838

1839
// We can registerize an IfThenElse that appears in the condition branch of a
1840
// Cond. This is a weird but valid thing to do.
1841
TEST(Registerizer, RegisterizerCondIfThenElse) {
1842
  BufHandle a("A", {5}, kInt);
1843
  BufHandle b("B", {5}, kInt);
1844
  BufHandle c("C", {5}, kInt);
1845
  VarHandle x("x", kInt);
1846

1847
  StmtPtr stmt = Block::make({Cond::make(
1848
      CompareSelect::make(
1849
          IfThenElse::make(
1850
              CompareSelect::make(
1851
                  Load::make(a, {x}), 5, CompareSelectOperation::kLT),
1852
              Load::make(a, {x}),
1853
              Load::make(b, {x})),
1854
          x,
1855
          CompareSelectOperation::kEQ),
1856
      Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),
1857
      nullptr)});
1858

1859
  /*
1860
   * if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
1861
   *   C[x] = (C[x]) + 1;
1862
   * }
1863
   */
1864

1865
  stmt = registerize(stmt);
1866

1867
  // access to A can be registerized, but not B or C
1868

1869
  /*
1870
   * int A_1 = A[x];
1871
   * if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
1872
   *   C[x] = (C[x]) + 1;
1873
   * }
1874
   */
1875

1876
  std::ostringstream oss;
1877
  oss << *stmt;
1878

1879
  const std::string& verification_pattern =
1880
      R"IR(
1881
# CHECK: int A_1 = A[x];
1882
# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
1883
# CHECK:   C[x] = (C[x]) + 1;)IR";
1884

1885
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1886
}
1887

1888
// Can registerize a conditional access in the RHS of a store unhidden by it's
1889
// LHS, and hoist it out of a loop.
1890
TEST(Registerizer, RegisterizerIfThenElseLoop) {
1891
  BufHandle a("A", {5}, kInt);
1892
  BufHandle b("B", {5}, kInt);
1893
  VarHandle x("x", kInt);
1894
  VarHandle y("y", kInt);
1895

1896
  StmtPtr stmt = For::make(
1897
      y,
1898
      0,
1899
      10,
1900
      Store::make(
1901
          a,
1902
          {x},
1903
          IfThenElse::make(
1904
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1905
              Load::make(a, {x}),
1906
              Load::make(b, {y}))));
1907

1908
  /*
1909
   * for (int y = 0; y < 10; y++) {
1910
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
1911
   * }
1912
   */
1913

1914
  stmt = registerize(stmt);
1915

1916
  /*
1917
   * int A_1 = A[x];
1918
   * for (int y = 0; y < 10; y++) {
1919
   *   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1920
   * }
1921
   * A[x] = A_1;
1922
   */
1923

1924
  std::ostringstream oss;
1925
  oss << *stmt;
1926

1927
  const std::string& verification_pattern =
1928
      R"IR(
1929
# CHECK: int A_1 = A[x];
1930
# CHECK: for (
1931
# CHECK:   A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1932
# CHECK: }
1933
# CHECK: A[x] = A_1;)IR";
1934

1935
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
1936
}
1937

1938
// Cannot registerize if the RHS overlaps the access creating visibility.
1939
TEST(Registerizer, RegisterizerIfThenElseLoopCut) {
1940
  BufHandle a("A", {5}, kInt);
1941
  BufHandle b("B", {5}, kInt);
1942
  VarHandle x("x", kInt);
1943
  VarHandle y("y", kInt);
1944

1945
  StmtPtr stmt = Block::make({For::make(
1946
      y,
1947
      0,
1948
      10,
1949
      Store::make(
1950
          a,
1951
          {x},
1952
          IfThenElse::make(
1953
              CompareSelect::make(x, 3, CompareSelectOperation::kLT),
1954
              Load::make(a, {x}),
1955
              Load::make(a, {y}))))});
1956

1957
  /*
1958
   * for (int y = 0; y < 10; y++) {
1959
   *   A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
1960
   * }
1961
   */
1962

1963
  std::ostringstream before;
1964
  before << *stmt;
1965

1966
  // No change.
1967
  stmt = registerize(stmt);
1968

1969
  std::ostringstream after;
1970
  after << *stmt;
1971

1972
  ASSERT_EQ(before.str(), after.str());
1973
}
1974

1975
// Simple case where an access is cut by an overlapping access later in the
1976
// program, we can registerize up until the overlap.
1977
TEST(Registerizer, RegisterizerPartialAfter) {
1978
  BufHandle a("A", {1}, kInt);
1979
  VarHandle x("x", kInt);
1980
  StmtPtr stmt = Block::make(
1981
      {Store::make(a, {0}, 0),
1982
       For::make(
1983
           x,
1984
           0,
1985
           10,
1986
           Block::make(
1987
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),
1988
       For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});
1989

1990
  /*
1991
   * A[0] = 0;
1992
   * for (int x = 0; x < 10; x++) {
1993
   *   A[0] = (A[0]) + x;
1994
   * }
1995
   * for (int x = 1; x < 10; x++) {
1996
   *   A[x] = A[x - 1];
1997
   * }
1998
   */
1999

2000
  stmt = registerize(stmt);
2001

2002
  /*
2003
   * int A_1 = 0;
2004
   * for (int x = 0; x < 10; x++) {
2005
   *   A_1 = A_1 + x;
2006
   * }
2007
   * A[0] = A_1;
2008
   * for (int x = 1; x < 10; x++) {
2009
   *   A[x] = A[x - 1];
2010
   * }
2011
   */
2012

2013
  std::ostringstream oss;
2014
  oss << *stmt;
2015

2016
  const std::string& verification_pattern =
2017
      R"IR(
2018
# CHECK: int A_1 = 0;
2019
# CHECK: for (
2020
# CHECK:   A_1 = A_1 + x;
2021
# CHECK: }
2022
# CHECK: A[0] = A_1;
2023
# CHECK: for (
2024
# CHECK:   A[x] = A[x - 1];
2025
# CHECK: }
2026
# CHECK-NOT: A)IR";
2027

2028
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2029
}
2030

2031
// We can registerize an access which overlaps a previous access, the
2032
// initializer must be inserted after the previous access.
2033
TEST(Registerizer, RegisterizerPartialBefore) {
2034
  BufHandle a("A", {1}, kInt);
2035
  VarHandle x("x", kInt);
2036
  StmtPtr stmt = Block::make(
2037
      {For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),
2038
       Store::make(a, {0}, 0),
2039
       For::make(
2040
           x,
2041
           0,
2042
           10,
2043
           Block::make(
2044
               {Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});
2045

2046
  /*
2047
   * for (int x = 1; x < 10; x++) {
2048
   *   A[x] = A[x - 1];
2049
   * }
2050
   * A[0] = 0;
2051
   * for (int x = 0; x < 10; x++) {
2052
   *   A[0] = (A[0]) + x;
2053
   * }
2054
   */
2055

2056
  stmt = registerize(stmt);
2057

2058
  /*
2059
   * for (int x = 1; x < 10; x++) {
2060
   *   A[x] = A[x - 1];
2061
   * }
2062
   * int A_1 = 0;
2063
   * for (int x = 0; x < 10; x++) {
2064
   *   A_1 = A_1 + x;
2065
   * }
2066
   * A[0] = A_1;
2067
   */
2068

2069
  std::ostringstream oss;
2070
  oss << *stmt;
2071

2072
  const std::string& verification_pattern =
2073
      R"IR(
2074
# CHECK-NOT: int
2075
# CHECK: for (
2076
# CHECK:   A[x] = A[x - 1];
2077
# CHECK: }
2078
# CHECK: int A_1 = 0;
2079
# CHECK: for (
2080
# CHECK:   A_1 = A_1 + x;
2081
# CHECK: }
2082
# CHECK: A[0] = A_1;)IR";
2083

2084
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2085
}
2086

2087
// The combination of the previous two tests, an access is cut by an overlapping
2088
// access in both directions.
2089
TEST(Registerizer, RegisterizerPartialInside) {
2090
  BufHandle a("A", {1}, kInt);
2091
  VarHandle x1("x1", kInt);
2092
  VarHandle x2("x2", kInt);
2093
  VarHandle x3("x3", kInt);
2094
  StmtPtr stmt = Block::make(
2095
      {Store::make(a, {0}, 2),
2096
       For::make(
2097
           x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),
2098
       For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),
2099
       For::make(
2100
           x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});
2101

2102
  /*
2103
   * A[0] = 2;
2104
   * for (int x1 = 0; x1 < 10; x1++) {
2105
   *   A[0] = (A[0]) + x1;
2106
   * }
2107
   * for (int x2 = 1; x2 < 10; x2++) {
2108
   *   A[x2] = A[x2 - 1];
2109
   * }
2110
   * for (int x3 = 0; x3 < 10; x3++) {
2111
   *   A[0] = (A[0]) + x3;
2112
   * }
2113
   */
2114

2115
  stmt = registerize(stmt);
2116

2117
  /*
2118
   * int A_1 = 2;
2119
   * for (int x1 = 0; x1 < 10; x1++) {
2120
   *   A_1 = A_1 + x1;
2121
   * }
2122
   * A[0] = A_1;
2123
   * for (int x2 = 1; x2 < 10; x2++) {
2124
   *   A[x2] = A[x2 - 1];
2125
   * }
2126
   * int A_2 = A[0];
2127
   * for (int x3 = 0; x3 < 10; x3++) {
2128
   *   A_2 = A_2 + x3;
2129
   * }
2130
   * A[0] = A_2;
2131
   */
2132

2133
  std::ostringstream oss;
2134
  oss << *stmt;
2135

2136
  const std::string& verification_pattern =
2137
      R"IR(
2138
# CHECK: int A_1 = 2;
2139
# CHECK: for (
2140
# CHECK:   A_1 = A_1 + x1;
2141
# CHECK: }
2142
# CHECK: A[0] = A_1;
2143
# CHECK: for (
2144
# CHECK:   A[x2] =
2145
# CHECK: }
2146
# CHECK: int A_2 = A[0];
2147
# CHECK: for (
2148
# CHECK:   A_2 = A_2 + x3;
2149
# CHECK: }
2150
# CHECK: A[0] = A_2;)IR";
2151

2152
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2153
}
2154

2155
// An element could be registerized program wide but is cut by a conditional
2156
// access, we should break this into two scalars and write back to the buffer
2157
// before the condition.
2158
TEST(Registerizer, RegisterizerPartialCondition) {
2159
  BufHandle a("A", {1}, kInt);
2160
  VarHandle x("x", kInt);
2161
  StmtPtr stmt = Block::make(
2162
      {Store::make(a, {0}, 2),
2163
       For::make(
2164
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),
2165
       Cond::make(
2166
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2167
           Store::make(a, {x}, Load::make(a, {x - 1})),
2168
           nullptr),
2169
       For::make(
2170
           x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});
2171

2172
  /*
2173
   * A[0] = 2;
2174
   * for (int x = 0; x < 10; x++) {
2175
   *   A[0] = (A[0]) + x;
2176
   * }
2177
   * if (x<5 ? 1 : 0) {
2178
   *   A[x] = A[x - 1];
2179
   * }
2180
   * for (int x = 0; x < 10; x++) {
2181
   *   A[0] = (A[0]) + x;
2182
   * }
2183
   */
2184

2185
  stmt = registerize(stmt);
2186

2187
  /*
2188
   * int A_1 = 2;
2189
   * for (int x = 0; x < 10; x++) {
2190
   *   A_1 = A_1 + x;
2191
   * }
2192
   * A[0] = A_1;
2193
   * if (x<5 ? 1 : 0) {
2194
   *   A[x] = A[x - 1];
2195
   * }
2196
   * int A_2 = A[0];
2197
   * for (int x = 0; x < 10; x++) {
2198
   *   A_2 = A_2 + x;
2199
   * }
2200
   * A[0] = A_2;
2201
   */
2202

2203
  std::ostringstream oss;
2204
  oss << *stmt;
2205

2206
  const std::string& verification_pattern =
2207
      R"IR(
2208
# CHECK: int A_1 = 2;
2209
# CHECK: for (
2210
# CHECK:   A_1 = A_1 + x;
2211
# CHECK: }
2212
# CHECK: A[0] = A_1;
2213
# CHECK: if (
2214
# CHECK:   A[x] =
2215
# CHECK: }
2216
# CHECK: int A_2 = A[0];
2217
# CHECK: for (
2218
# CHECK:   A_2 = A_2 + x;
2219
# CHECK: }
2220
# CHECK: A[0] = A_2;)IR";
2221

2222
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2223
}
2224

2225
// Tests case where an access is cut by an internal conditional access which
2226
// itself is registerized.
2227
TEST(Registerizer, RegisterizerPartialConditionInternalCut) {
2228
  BufHandle a("A", {1}, kInt);
2229
  VarHandle x("x", kInt);
2230
  StmtPtr stmt = Block::make(
2231
      {Store::make(a, {0}, 1),
2232
       Store::make(a, {0}, 3),
2233
       Cond::make(
2234
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2235
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2236
           nullptr),
2237
       Store::make(a, {0}, 4),
2238
       Store::make(a, {0}, 6)});
2239

2240
  /*
2241
   * A[0] = 1;
2242
   * A[0] = 3;
2243
   * if (x<5 ? 1 : 0) {
2244
   *   A[x] = 1;
2245
   *   A[x] = 3;
2246
   * }
2247
   * A[0] = 4;
2248
   * A[0] = 6;
2249
   */
2250

2251
  stmt = registerize(stmt);
2252

2253
  /*
2254
   * int A_1 = 1;
2255
   * A_1 = 3;
2256
   * A[0] = A_1;
2257
   * if (x<5 ? 1 : 0) {
2258
   *   int A_2 = 1;
2259
   *   A_2 = 3;
2260
   *   A[x] = A_2;
2261
   * }
2262
   * int A_3 = 4;
2263
   * A_3 = 6;
2264
   * A[0] = A_3;
2265
   */
2266

2267
  std::ostringstream oss;
2268
  oss << *stmt;
2269

2270
  const std::string& verification_pattern =
2271
      R"IR(
2272
# CHECK: int A_1 = 1;
2273
# CHECK: A_1 = 3
2274
# CHECK: A[0] = A_1;
2275
# CHECK: if (
2276
# CHECK:   int A_2 = 1;
2277
# CHECK:   A_2 = 3;
2278
# CHECK:   A[x] = A_2;
2279
# CHECK: }
2280
# CHECK: int A_3 = 4;
2281
# CHECK: A_3 = 6;
2282
# CHECK: A[0] = A_3;)IR";
2283

2284
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2285
}
2286

2287
// First statement in condition closes outer access, but can be registerized
2288
// with later statements.
2289
TEST(Registerizer, RegisterizerPartialConditionInternalStart) {
2290
  BufHandle a("A", {1}, kInt);
2291
  VarHandle x("x", kInt);
2292
  StmtPtr stmt = Block::make(
2293
      {Store::make(a, {0}, 1),
2294
       Store::make(a, {0}, 3),
2295
       Cond::make(
2296
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2297
           Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),
2298
           nullptr),
2299
       Store::make(a, {x}, 4),
2300
       Store::make(a, {x}, 6)});
2301

2302
  /*
2303
   * A[0] = 1;
2304
   * A[0] = 3;
2305
   * if (x<5 ? 1 : 0) {
2306
   *   A[x] = 1;
2307
   *   A[x] = 3;
2308
   * }
2309
   * A[x] = 4;
2310
   * A[x] = 6;
2311
   */
2312

2313
  stmt = registerize(stmt);
2314

2315
  /*
2316
   * int A_1 = 1;
2317
   * A_1 = 3;
2318
   * A[0] = A_1;
2319
   * int A_2 = A[x];    <--- must read from the input here.
2320
   * if (x<5 ? 1 : 0) {
2321
   *   A_2 = 1;
2322
   *   A_2 = 3;
2323
   * }
2324
   * A_2 = 4;
2325
   * A_2 = 6;
2326
   * A[x] = A_2;
2327
   */
2328

2329
  // TODO: I suppose we could refactor with a conditional initializer?
2330

2331
  std::ostringstream oss;
2332
  oss << *stmt;
2333

2334
  const std::string& verification_pattern =
2335
      R"IR(
2336
# CHECK: int A_1 = 1;
2337
# CHECK: A_1 = 3
2338
# CHECK: A[0] = A_1;
2339
# CHECK: int A_2 = A[x];
2340
# CHECK: if (
2341
# CHECK:   A_2 = 1;
2342
# CHECK:   A_2 = 3;
2343
# CHECK: }
2344
# CHECK: A_2 = 4;
2345
# CHECK: A_2 = 6;
2346
# CHECK: A[x] = A_2;)IR";
2347

2348
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2349
}
2350

2351
// An access cuts two open overlaps and creates four scalar variables.
2352
TEST(Registerizer, RegisterizerPartialOverlapsTwo) {
2353
  BufHandle a("A", {1}, kInt);
2354
  VarHandle x("x", kInt);
2355
  StmtPtr stmt = Block::make(
2356
      {Store::make(a, {1}, Load::make(a, {0})),
2357
       Store::make(a, {0}, Load::make(a, {1})),
2358
       Store::make(a, {0}, Load::make(a, {1})),
2359
       For::make(x, 1, 10, Store::make(a, {x}, x)),
2360
       Store::make(a, {1}, Load::make(a, {0})),
2361
       Store::make(a, {0}, Load::make(a, {1})),
2362
       Store::make(a, {0}, Load::make(a, {1}))});
2363

2364
  /*
2365
   * A[1] = A[0];
2366
   * A[0] = A[1];
2367
   * A[0] = A[1];
2368
   * for (int x = 1; x < 10; x++) {
2369
   *   A[x] = x;
2370
   * }
2371
   * A[1] = A[0];
2372
   * A[0] = A[1];
2373
   * A[0] = A[1];
2374
   */
2375

2376
  stmt = registerize(stmt);
2377

2378
  /*
2379
   * int A_1 = A[0];
2380
   * int A_2 = A_1;
2381
   * A_1 = A_2;
2382
   * A_1 = A_2;
2383
   * A[1] = A_2;
2384
   * A[0] = A_1;
2385
   * for (int x = 1; x < 10; x++) {
2386
   *   A[x] = x;
2387
   * }
2388
   * int A_3 = A[0];
2389
   * int A_4 = A_3;
2390
   * A_3 = A_4;
2391
   * A_3 = A_4;
2392
   * A[1] = A_4;
2393
   * A[0] = A_3;
2394
   */
2395

2396
  std::ostringstream oss;
2397
  oss << *stmt;
2398

2399
  const std::string& verification_pattern =
2400
      R"IR(
2401
# CHECK: int A_1 = A[0];
2402
# CHECK: int A_2 = A_1;
2403
# CHECK: A_1 = A_2;
2404
# CHECK: A_1 = A_2;
2405
# CHECK: A[1] = A_2;
2406
# CHECK: A[0] = A_1;
2407
# CHECK: for (
2408
# CHECK:   A[x] = x;
2409
# CHECK: }
2410
# CHECK: int A_3 = A[0];
2411
# CHECK: int A_4 = A_3;
2412
# CHECK: A_3 = A_4;
2413
# CHECK: A_3 = A_4;
2414
# CHECK: A[1] = A_4;
2415
# CHECK: A[0] = A_3;)IR";
2416

2417
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2418
}
2419

2420
// Nested blocks will automatically be flattened and do not provent
2421
// registerization of enclosed accesses.
2422
TEST(Registerizer, RegisterizerNestedBlocks) {
2423
  BufHandle a("A", {1}, kInt);
2424
  VarHandle x("x", kInt);
2425
  StmtPtr stmt = Block::make(
2426
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2427
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2428
       Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),
2429
       Block::make(
2430
           {Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),
2431
            Block::make(
2432
                {Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});
2433

2434
  /*
2435
   * A[0] = (A[0]) + 1;
2436
   * {
2437
   *   A[0] = (A[0]) + 2;
2438
   * }
2439
   * {
2440
   *   A[0] = (A[0]) + 3;
2441
   *   {
2442
   *     A[0] = (A[0]) + 4;
2443
   *   }
2444
   * }
2445
   */
2446

2447
  stmt = registerize(stmt);
2448

2449
  /*
2450
   * int A_1 = A[0];
2451
   * A_1 = A_1 + 1;
2452
   * A_1 = A_1 + 2;
2453
   * A_1 = A_1 + 3;
2454
   * A_1 = A_1 + 4;
2455
   * A[0] = A_1;
2456
   */
2457

2458
  std::ostringstream oss;
2459
  oss << *stmt;
2460

2461
  const std::string& verification_pattern =
2462
      R"IR(
2463
# CHECK: int A_1 = A[0];
2464
# CHECK: A_1 = A_1 + 1;
2465
# CHECK: A_1 = A_1 + 2;
2466
# CHECK: A_1 = A_1 + 3;
2467
# CHECK: A_1 = A_1 + 4;
2468
# CHECK: A[0] = A_1;)IR";
2469

2470
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2471
}
2472

2473
// The access can be registerized internally to a condition, but must ensure
2474
// that both initializer and finalizer are within the same condition.
2475
TEST(Registerizer, RegisterizerNestedConditions) {
2476
  BufHandle a("A", {1}, kInt);
2477
  VarHandle x("x", kInt);
2478
  StmtPtr stmt = Block::make({Cond::make(
2479
      CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2480
      Block::make(
2481
          {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2482
           Cond::make(
2483
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2484
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2485
               nullptr)}),
2486
      nullptr)});
2487

2488
  /*
2489
   * if (x<5 ? 1 : 0) {
2490
   *   A[0] = (A[0]) + 1;
2491
   *   if (x==2 ? 1 : 0) {
2492
   *
2493
   *     A[0] = (A[0]) + 1;
2494
   *   }
2495
   * }
2496
   */
2497

2498
  stmt = registerize(stmt);
2499

2500
  /*
2501
   * if (x<5 ? 1 : 0) {
2502
   *   int A_1 = A[0];
2503
   *   A_1 = A_1 + 1;
2504
   *   if (x==2 ? 1 : 0) {
2505
   *     A_1 = A_1 + 1;
2506
   *   }
2507
   * A[0] = A_1;
2508
   * }
2509
   */
2510

2511
  std::ostringstream oss;
2512
  oss << *stmt;
2513

2514
  const std::string& verification_pattern =
2515
      R"IR(
2516
# CHECK: if (x<5
2517
# CHECK:   int A_1 = A[0];
2518
# CHECK:   A_1 = A_1 + 1;
2519
# CHECK:   if (x==2
2520
# CHECK:     A_1 = A_1 + 1;
2521
# CHECK:   }
2522
# CHECK: A[0] = A_1;
2523
# CHECK: })IR";
2524

2525
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2526
}
2527

2528
// If an access exists outside the scope of the condition then we can lift
2529
// nested conditional usages into the same scalar.
2530
TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {
2531
  BufHandle a("A", {1}, kInt);
2532
  VarHandle x("x", kInt);
2533
  StmtPtr stmt = Block::make(
2534
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2535
       Cond::make(
2536
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2537
           Block::make(
2538
               {Store::make(a, {1}, 1),
2539
                Cond::make(
2540
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2541
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2542
                    nullptr)}),
2543
           nullptr)});
2544

2545
  /*
2546
   * A[0] = (A[0]) + 1;
2547
   * if (x<5 ? 1 : 0) {
2548
   *   A[1] = 1;
2549
   *   if (x==2 ? 1 : 0) {
2550
   *     A[0] = (A[0]) + 1;
2551
   *   }
2552
   * }
2553
   */
2554

2555
  stmt = registerize(stmt);
2556

2557
  /*
2558
   * int A_1 = A[0];
2559
   * A_1 = A_1 + 1;
2560
   * if (x<5 ? 1 : 0) {
2561
   *   A[1] = 1;
2562
   *   if (x==2 ? 1 : 0) {
2563
   *     A_1 = A_1 + 1;
2564
   *   }
2565
   * }
2566
   * A[0] = A_1;
2567
   */
2568

2569
  std::ostringstream oss;
2570
  oss << *stmt;
2571

2572
  const std::string& verification_pattern =
2573
      R"IR(
2574
# CHECK: int A_1 = A[0];
2575
# CHECK: A_1 = A_1 + 1;
2576
# CHECK: if (x<5
2577
# CHECK:   A[1] = 1;
2578
# CHECK:   if (x==2
2579
# CHECK:     A_1 = A_1 + 1;
2580
# CHECK: A[0] = A_1;)IR";
2581

2582
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2583
}
2584

2585
// A[0] is only ever accessed under conditions: first inside a single
// condition, then inside a condition nested in another. The test asserts that
// registerization leaves the statement unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr)});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // Runs the pass once more on its own output; the result is not inspected.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}

// A[0] is only ever accessed under conditions: first inside a condition nested
// in another, then inside a single condition. The test asserts that
// registerization leaves the statement unchanged.
TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {
  BufHandle a("A", {1}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
           Block::make({Cond::make(
               CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
               Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
               nullptr)}),
           nullptr),
       Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr)});

  /*
   * if (x<5 ? 1 : 0) {
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());

  // Runs the pass once more on its own output; the result is not inspected.
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  stmt = registerize(stmt);
}

// If an access is cut by another access internal to a condition block, it still
2670
// cuts the access.
2671
TEST(Registerizer, RegisterizerNestedConditionsCut) {
2672
  BufHandle a("A", {1}, kInt);
2673
  VarHandle x("x", kInt);
2674
  StmtPtr stmt = Block::make(
2675
      {Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2676
       Cond::make(
2677
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
2678
           Block::make(
2679
               {Store::make(a, {x}, 1),
2680
                Cond::make(
2681
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2682
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
2683
                    nullptr)}),
2684
           nullptr)});
2685

2686
  /*
2687
   * A[0] = (A[0]) + 1;
2688
   * if (x<5 ? 1 : 0) {
2689
   *   A[x] = 1;
2690
   *   if (x==2 ? 1 : 0) {
2691
   *
2692
   *     A[0] = (A[0]) + 1;
2693
   *   }
2694
   * }
2695
   */
2696

2697
  std::ostringstream before;
2698
  before << *stmt;
2699

2700
  // No change.
2701
  stmt = registerize(stmt);
2702

2703
  std::ostringstream after;
2704
  after << *stmt;
2705

2706
  ASSERT_EQ(before.str(), after.str());
2707
}
2708

2709
// A[0] is only accessed under conditions, once at the top level and once
// inside a loop body. The test asserts that registerization leaves the
// statement unchanged.
TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {
  BufHandle a("A", {10}, kInt);
  BufHandle b("B", {10}, kInt);
  VarHandle x("x", kInt);
  StmtPtr stmt = Block::make(
      {Cond::make(
           CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
           Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
           nullptr),
       For::make(
           x,
           0,
           10,
           Block::make(
               {Store::make(b, {x}, 0),
                Cond::make(
                    CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
                    Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),
                    nullptr)}))});

  /*
   * if (x==2 ? 1 : 0) {
   *   A[0] = (A[0]) + 1;
   * }
   * for (int x = 0; x < 10; x++) {
   *   B[x] = 0;     <-- this is only here to prevent Loop/Cond reordering.
   *   if (x==2 ? 1 : 0) {
   *     A[0] = (A[0]) + 1;
   *   }
   * }
   */

  // Snapshot the printed IR so it can be compared after the pass.
  std::ostringstream before;
  before << *stmt;

  // No change.
  stmt = registerize(stmt);

  std::ostringstream after;
  after << *stmt;

  ASSERT_EQ(before.str(), after.str());
}

// Three loops and four element regions, three of which should be registerized
2754
// at different levels of the IR.
2755
TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {
2756
  BufHandle a("A", {10}, kInt);
2757
  BufHandle b("B", {10}, kInt);
2758
  VarHandle x("x", kInt);
2759
  StmtPtr stmt = Block::make(
2760
      {Store::make(a, {4}, 0),
2761
       Cond::make(
2762
           CompareSelect::make(x, 2, CompareSelectOperation::kGT),
2763
           Cond::make(
2764
               CompareSelect::make(x, 3, CompareSelectOperation::kGT),
2765
               Block::make({
2766
                   Cond::make(
2767
                       CompareSelect::make(x, 4, CompareSelectOperation::kGT),
2768
                       Block::make({
2769
                           Store::make(
2770
                               a, {1}, Add::make(Load::make(a, {1}), 1)),
2771
                           Store::make(
2772
                               a, {2}, Add::make(Load::make(a, {2}), 1)),
2773
                           Store::make(
2774
                               a, {3}, Add::make(Load::make(a, {3}), 1)),
2775
                           Store::make(
2776
                               a, {4}, Add::make(Load::make(a, {4}), 1)),
2777
                           Store::make(
2778
                               a, {1}, Add::make(Load::make(a, {1}), 1)),
2779
                       }),
2780
                       nullptr),
2781
                   Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),
2782
               }),
2783
               nullptr),
2784
           nullptr)});
2785

2786
  /*
2787
   * A[4] = 0;
2788
   * if (x>2 ? 1 : 0) {
2789
   *   if (x>3 ? 1 : 0) {
2790
   *     if (x>4 ? 1 : 0) {
2791
   *       A[1] = (A[1]) + 1;
2792
   *       A[2] = (A[2]) + 1;
2793
   *       A[3] = (A[3]) + 1;
2794
   *       A[4] = (A[4]) + 1;
2795
   *       A[1] = (A[1]) + 1;
2796
   *     }
2797
   *     A[2] = (A[2]) + 1;
2798
   *   }
2799
   * }
2800
   */
2801

2802
  stmt = registerize(stmt);
2803

2804
  /*
2805
   * int A_1 = 0;
2806
   * if (x>2 ? 1 : 0) {
2807
   *   if (x>3 ? 1 : 0) {
2808
   *     int A_3 = A[2];
2809
   *     if (x>4 ? 1 : 0) {
2810
   *       int A_2 = A[1];
2811
   *       A_2 = A_2 + 1;
2812
   *       A_3 = A_3 + 1;
2813
   *       A[3] = (A[3]) + 1;
2814
   *       A_1 = A_1 + 1;
2815
   *       A_2 = A_2 + 1;
2816
   *       A[1] = A_2;
2817
   *     }
2818
   *     A_3 = A_3 + 1;
2819
   *     A[2] = A_3;
2820
   *   }
2821
   * }
2822
   * A[4] = A_1;
2823
   */
2824

2825
  std::ostringstream oss;
2826
  oss << *stmt;
2827

2828
  const std::string& verification_pattern =
2829
      R"IR(
2830
# CHECK: int A_1 = 0;
2831
# CHECK: if (x>2 ? 1 : 0) {
2832
# CHECK:   if (x>3 ? 1 : 0) {
2833
# CHECK:     int A_3 = A[2];
2834
# CHECK:     if (x>4 ? 1 : 0) {
2835
# CHECK:       int A_2 = A[1];
2836
# CHECK:       A_2 = A_2 + 1;
2837
# CHECK:       A_3 = A_3 + 1;
2838
# CHECK:       A[3] = (A[3]) + 1;
2839
# CHECK:       A_1 = A_1 + 1;
2840
# CHECK:       A_2 = A_2 + 1;
2841
# CHECK:       A[1] = A_2;
2842
# CHECK:     }
2843
# CHECK:     A_3 = A_3 + 1;
2844
# CHECK:     A[2] = A_3;
2845
# CHECK:   }
2846
# CHECK: }
2847
# CHECK: A[4] = A_1;)IR";
2848

2849
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2850
}
2851

2852
// Can replace a simple scalar access with a local variable even when that
2853
// variable is an outer loop var.
2854
TEST(Registerizer, RegisterizerNestedLoopSimple) {
2855
  BufHandle a("A", {1}, kInt);
2856
  VarHandle x("x", kInt);
2857
  VarHandle y("y", kInt);
2858
  StmtPtr stmt = Block::make({For::make(
2859
      y,
2860
      0,
2861
      10,
2862
      For::make(
2863
          x,
2864
          0,
2865
          10,
2866
          Block::make(
2867
              {Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});
2868

2869
  /*
2870
   * for (int y = 0; y < 10; y++) {
2871
   *   for (int x = 0; x < 10; x++) {
2872
   *     A[y] = (A[y]) + x;
2873
   *   }
2874
   * }
2875
   */
2876

2877
  stmt = registerize(stmt);
2878

2879
  /*
2880
   * for (int y = 0; y < 10; y++) {
2881
   *   int A_1 = A[y];
2882
   *   for (int x = 0; x < 10; x++) {
2883
   *     A_1 = A_1 + x;
2884
   *   }
2885
   * A[y] = A_1;
2886
   * }
2887
   */
2888

2889
  std::ostringstream oss;
2890
  oss << *stmt;
2891

2892
  const std::string& verification_pattern =
2893
      R"IR(
2894
# CHECK: for (int y
2895
# CHECK:   int A_1 = A[y];
2896
# CHECK:   for (int x
2897
# CHECK:     A_1 = A_1 + x;
2898
# CHECK:   }
2899
# CHECK:   A[y] = A_1;
2900
# CHECK: })IR";
2901

2902
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2903
}
2904

2905
// Test the positive case of the hiddenAccess split, where an internal
2906
// conditional access can be hoisted up through a loop to match an existing
2907
// access in a higher scope and the two can be registerized.
2908
TEST(Registerizer, RegisterizerHiddenAccessYes) {
2909
  BufHandle a("A", {10}, kInt);
2910
  BufHandle b("B", {10}, kInt);
2911
  VarHandle x("x", kInt);
2912
  VarHandle y("y", kInt);
2913
  StmtPtr stmt = Block::make({Cond::make(
2914
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2915
      Block::make(
2916
          {Store::make(a, {0}, 0),
2917
           For::make(
2918
               x,
2919
               0,
2920
               10,
2921
               Block::make(
2922
                   {Store::make(b, {x}, 0),
2923
                    // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
2924
                    Cond::make(
2925
                        CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
2926
                        For::make(
2927
                            y,
2928
                            0,
2929
                            10,
2930
                            Store::make(
2931
                                a, {0}, Add::make(Load::make(a, {0}), 1))),
2932
                        nullptr)}))}),
2933
      nullptr)});
2934

2935
  /*
2936
   * if (x==2 ? 1 : 0) {
2937
   *   A[0] = 0;
2938
   *   for (int x = 0; x < 10; x++) {
2939
   *     B[x] = 0;
2940
   *     if (x==3 ? 1 : 0) {
2941
   *       for (int y = 0; y < 10; y++) {
2942
   *         A[0] = (A[0]) + 1;
2943
   *       }
2944
   *     }
2945
   *   }
2946
   * }
2947
   */
2948

2949
  stmt = registerize(stmt);
2950

2951
  /*
2952
   * if (x==2 ? 1 : 0) {
2953
   *   int A_1 = 0;
2954
   *   for (int x = 0; x < 10; x++) {
2955
   *     B[x] = 0;
2956
   *     if (x==3 ? 1 : 0) {
2957
   *       for (int y = 0; y < 10; y++) {
2958
   *         A_1 = A_1 + 1;
2959
   *       }
2960
   *     }
2961
   *   }
2962
   *   A[0] = A_1;
2963
   * }
2964
   */
2965

2966
  std::ostringstream oss;
2967
  oss << *stmt;
2968

2969
  const std::string& verification_pattern =
2970
      R"IR(
2971
# CHECK: if (x==2
2972
# CHECK:   int A_1 = 0;
2973
# CHECK:   for (int x
2974
# CHECK:     B[x] = 0;
2975
# CHECK:     if (x==3
2976
# CHECK:       for (int y
2977
# CHECK:         A_1 = A_1 + 1;
2978
# CHECK:       }
2979
# CHECK:     }
2980
# CHECK:   }
2981
# CHECK:  A[0] = A_1;
2982
# CHECK: })IR";
2983

2984
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
2985
}
2986

2987
// Test the negative case of the hiddenAccess split, where the hoisted access is
2988
// never unhidden at a higher scope and registerization occurs at the lower
2989
// scope.
2990
TEST(Registerizer, RegisterizerHiddenAccessNo) {
2991
  BufHandle a("A", {10}, kInt);
2992
  BufHandle b("B", {10}, kInt);
2993
  VarHandle x("x", kInt);
2994
  VarHandle y("y", kInt);
2995
  StmtPtr stmt = Block::make({Cond::make(
2996
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
2997
      Block::make({For::make(
2998
          x,
2999
          0,
3000
          10,
3001
          Block::make(
3002
              {Store::make(b, {x}, 0),
3003
               // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3004
               Cond::make(
3005
                   CompareSelect::make(x, 3, CompareSelectOperation::kEQ),
3006
                   For::make(
3007
                       y,
3008
                       0,
3009
                       10,
3010
                       Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3011
                   nullptr)}))}),
3012
      nullptr)});
3013

3014
  /*
3015
   * if (x==2 ? 1 : 0) {
3016
   *   A[0] = 0;
3017
   *   for (int x = 0; x < 10; x++) {
3018
   *     B[x] = 0;
3019
   *     if (x==3 ? 1 : 0) {
3020
   *       for (int y = 0; y < 10; y++) {
3021
   *         A[0] = (A[0]) + 1;
3022
   *       }
3023
   *     }
3024
   *   }
3025
   * }
3026
   */
3027

3028
  stmt = registerize(stmt);
3029

3030
  /*
3031
   * if (x==2 ? 1 : 0) {
3032
   *   for (int x = 0; x < 10; x++) {
3033
   *     B[x] = 0;
3034
   *     if (x==3 ? 1 : 0) {
3035
   *       int A_1 = A[0];
3036
   *       for (int y = 0; y < 10; y++) {
3037
   *         A_1 = A_1 + 1;
3038
   *       }
3039
   *       A[0] = A_1;
3040
   *     }
3041
   *   }
3042
   * }
3043
   */
3044

3045
  std::ostringstream oss;
3046
  oss << *stmt;
3047

3048
  const std::string& verification_pattern =
3049
      R"IR(
3050
# CHECK: if (x==2
3051
# CHECK:   for (int x
3052
# CHECK:     B[x] = 0;
3053
# CHECK:     if (x==3
3054
# CHECK:       int A_1 = A[0];
3055
# CHECK:       for (int y
3056
# CHECK:         A_1 = A_1 + 1;
3057
# CHECK:       }
3058
# CHECK:       A[0] = A_1;
3059
# CHECK:     }
3060
# CHECK:   }
3061
# CHECK: })IR";
3062

3063
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3064
}
3065

3066
// In this case the conditional access must be hoisted by two loops, there are
3067
// two accesses here one is unhidden and the other isnt. A[0] can be
3068
// registerized but B[0] cannot.
3069
TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {
3070
  BufHandle a("A", {10}, kInt);
3071
  BufHandle b("B", {10}, kInt);
3072
  VarHandle x("x", kInt);
3073
  VarHandle y("y", kInt);
3074
  StmtPtr stmt = Block::make({Cond::make(
3075
      CompareSelect::make(x, 2, CompareSelectOperation::kEQ),
3076
      Block::make(
3077
          {Store::make(a, {0}, 0),
3078
           For::make(
3079
               x,
3080
               0,
3081
               10,
3082
               For::make(
3083
                   y,
3084
                   0,
3085
                   10,
3086
                   Block::make({Cond::make(
3087
                       CompareSelect::make(y, 3, CompareSelectOperation::kEQ),
3088
                       Block::make(
3089
                           {Store::make(
3090
                                a, {0}, Add::make(Load::make(a, {0}), 1)),
3091
                            Store::make(
3092
                                b, {0}, Add::make(Load::make(b, {0}), 1))}),
3093
                       nullptr)})))}),
3094
      nullptr)});
3095

3096
  /*
3097
   * if (x==2 ? 1 : 0) {
3098
   *   A[0] = 0;
3099
   *   for (int x = 0; x < 10; x++) {
3100
   *     for (int y = 0; y < 10; y++) {
3101
   *       if (y==3 ? 1 : 0) {
3102
   *         A[0] = (A[0]) + 1;
3103
   *         B[0] = (B[0]) + 1;
3104
   *       }
3105
   *     }
3106
   *   }
3107
   * }
3108
   */
3109

3110
  stmt = registerize(stmt);
3111

3112
  /*
3113
   * if (x==2 ? 1 : 0) {
3114
   *   int A_1 = 0;
3115
   *   for (int x = 0; x < 10; x++) {
3116
   *     for (int y = 0; y < 10; y++) {
3117
   *       if (y==3 ? 1 : 0) {
3118
   *         A_1 = A_1 + 1;
3119
   *         B[0] = (B[0]) + 1;
3120
   *       }
3121
   *     }
3122
   *   }
3123
   *   A[0] = A_1;
3124
   * }
3125
   */
3126

3127
  std::ostringstream oss;
3128
  oss << *stmt;
3129

3130
  const std::string& verification_pattern =
3131
      R"IR(
3132
# CHECK: if (x==2
3133
# CHECK:   int A_1 = 0;
3134
# CHECK:   for (int x
3135
# CHECK:     for (int y
3136
# CHECK:       if (y==3
3137
# CHECK:         A_1 = A_1 + 1;
3138
# CHECK:         B[0] = (B[0]) + 1;
3139
# CHECK:       }
3140
# CHECK:     }
3141
# CHECK:   }
3142
# CHECK:  A[0] = A_1;
3143
# CHECK: })IR";
3144

3145
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3146
}
3147

3148
// Accesses are registerized inside two conditions, but the immediate parent is
3149
// not a condition.
3150
TEST(Registerizer, RegisterizerTwoConditionalLoops) {
3151
  BufHandle a("A", {1}, kInt);
3152
  VarHandle x("x", kInt);
3153
  StmtPtr stmt = Block::make(
3154
      {Cond::make(
3155
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3156
           For::make(
3157
               x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3158
           nullptr),
3159
       Cond::make(
3160
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3161
           For::make(
3162
               x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3163
           nullptr)});
3164

3165
  /*
3166
   * if (x<5 ? 1 : 0) {
3167
   *   for (int x = 0; x < 10; x++) {
3168
   *     A[0] = (A[0]) + 1;
3169
   *   }
3170
   * }
3171
   * if (x>5 ? 1 : 0) {
3172
   *   for (int x = 0; x < 10; x++) {
3173
   *     A[0] = (A[0]) + 1;
3174
   *   }
3175
   * }
3176
   */
3177

3178
  stmt = registerize(stmt);
3179

3180
  /*
3181
   * if (x<5 ? 1 : 0) {
3182
   *   int A_1 = A[0];
3183
   *   for (int x = 0; x < 10; x++) {
3184
   *     A_1 = A_1 + 1;
3185
   *   }
3186
   *   A[0] = A_1;
3187
   * }
3188
   * if (x>5 ? 1 : 0) {
3189
   *   int A_2 = A[0];
3190
   *   for (int x = 0; x < 10; x++) {
3191
   *     A_2 = A_2 + 1;
3192
   *   }
3193
   *   A[0] = A_2;
3194
   * }
3195
   */
3196

3197
  std::ostringstream oss;
3198
  oss << *stmt;
3199

3200
  const std::string& verification_pattern =
3201
      R"IR(
3202
# CHECK: if (x<5
3203
# CHECK:   int A_1 = A[0];
3204
# CHECK:   for (int x
3205
# CHECK:     A_1 = A_1 + 1;
3206
# CHECK:   }
3207
# CHECK:   A[0] = A_1;
3208
# CHECK: }
3209
# CHECK: if (x>5
3210
# CHECK:   int A_2 = A[0];
3211
# CHECK:   for (int x
3212
# CHECK:     A_2 = A_2 + 1;
3213
# CHECK:   }
3214
# CHECK:   A[0] = A_2;
3215
# CHECK: })IR";
3216

3217
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3218
}
3219

3220
// Accesses are registerized inside two conditions, cut in the middle.
3221
TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {
3222
  BufHandle a("A", {1}, kInt);
3223
  VarHandle x("x", kInt);
3224
  StmtPtr stmt = Block::make(
3225
      {Cond::make(
3226
           CompareSelect::make(x, 5, CompareSelectOperation::kLT),
3227
           For::make(
3228
               x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3229
           nullptr),
3230
       For::make(x, 0, 10, Store::make(a, {x}, 1)),
3231
       Cond::make(
3232
           CompareSelect::make(x, 5, CompareSelectOperation::kGT),
3233
           For::make(
3234
               x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),
3235
           nullptr)});
3236

3237
  /*
3238
   * if (x<5 ? 1 : 0) {
3239
   *   for (int x = 0; x < 10; x++) {
3240
   *     A[0] = (A[0]) + 1;
3241
   *   }
3242
   * }
3243
   * for (int x = 0; x < 10; x++) {
3244
   *   A[x] = 1;
3245
   * }
3246
   * if (x>5 ? 1 : 0) {
3247
   *   for (int x = 0; x < 10; x++) {
3248
   *     A[0] = (A[0]) + 1;
3249
   *   }
3250
   * }
3251
   */
3252

3253
  stmt = registerize(stmt);
3254

3255
  /*
3256
   * if (x<5 ? 1 : 0) {
3257
   *   int A_1 = A[0];
3258
   *   for (int x = 0; x < 10; x++) {
3259
   *     A_1 = A_1 + 1;
3260
   *   }
3261
   *   A[0] = A_1;
3262
   * }
3263
   * for (int x = 0; x < 10; x++) {
3264
   *   A[x] = 1;
3265
   * }
3266
   * if (x>5 ? 1 : 0) {
3267
   *   int A_2 = A[0];
3268
   *   for (int x = 0; x < 10; x++) {
3269
   *     A_2 = A_2 + 1;
3270
   *   }
3271
   *   A[0] = A_2;
3272
   * }
3273
   */
3274

3275
  std::ostringstream oss;
3276
  oss << *stmt;
3277

3278
  const std::string& verification_pattern =
3279
      R"IR(
3280
# CHECK: if (x<5
3281
# CHECK:   int A_1 = A[0];
3282
# CHECK:   for (int x
3283
# CHECK:     A_1 = A_1 + 1;
3284
# CHECK:   }
3285
# CHECK:   A[0] = A_1;
3286
# CHECK: }
3287
# CHECK: for (int x
3288
# CHECK:  A[x] = 1;
3289
# CHECK: if (x>5
3290
# CHECK:   int A_2 = A[0];
3291
# CHECK:   for (int x
3292
# CHECK:     A_2 = A_2 + 1;
3293
# CHECK:   }
3294
# CHECK:   A[0] = A_2;
3295
# CHECK: })IR";
3296

3297
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3298
}
3299

3300
// references a Let var in a local scope which cannot be hoisted out of the
3301
// loop.
3302
TEST(Registerizer, RegisterizerLoopLetVar) {
3303
  BufHandle a("A", {10}, kInt);
3304
  VarHandle x("x", kInt);
3305
  VarHandle y("y", kInt);
3306
  StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(
3307
      x,
3308
      0,
3309
      10,
3310
      Block::make(
3311
          {Let::make(y, 30),
3312
           Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));
3313

3314
  /*
3315
   * for (int x = 0; x < 10; x++) {
3316
   *   int y = 30;
3317
   *   A[y] = x + (A[y]);
3318
   * }
3319
   */
3320

3321
  std::ostringstream before;
3322
  before << *stmt;
3323

3324
  // No change.
3325
  stmt = registerize(stmt);
3326

3327
  std::ostringstream after;
3328
  after << *stmt;
3329

3330
  ASSERT_EQ(before.str(), after.str());
3331
}
3332

3333
// references a Let var in an outer scope that does not prevent hoisting the
3334
// initializer.
3335
TEST(Registerizer, RegisterizerLoopLetVarOuter) {
3336
  BufHandle a("A", {10}, kInt);
3337
  VarHandle x("x", kInt);
3338
  VarHandle y("y", kInt);
3339
  StmtPtr stmt = Block::make(
3340
      {Let::make(y, 30),
3341
       For::make(
3342
           x,
3343
           0,
3344
           10,
3345
           Block::make(
3346
               {Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});
3347

3348
  /*
3349
   * int y = 30;
3350
   * for (int x = 0; x < 10; x++) {
3351
   *   A[y] = x + (A[y]);
3352
   * }
3353
   */
3354

3355
  stmt = registerize(stmt);
3356

3357
  /*
3358
   * int y = 30;
3359
   * int A_1 = A[y];
3360
   * for (int x = 0; x < 10; x++) {
3361
   *   A_1 = A_1 + x;
3362
   * }
3363
   * A[y] = A_1;
3364
   */
3365

3366
  std::ostringstream oss;
3367
  oss << *stmt;
3368

3369
  const std::string& verification_pattern =
3370
      R"IR(
3371
# CHECK: int y = 30;
3372
# CHECK: int A_1 = A[y];
3373
# CHECK: for (int x
3374
# CHECK:   A_1 = A_1 + x;
3375
# CHECK: A[y] = A_1;)IR";
3376

3377
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3378
}
3379

3380
// Okay so the registerizer generally goes after index flattening, but just in
3381
// case. Test multi index registerization.
3382
TEST(Registerizer, RegisterizerMultiDim) {
3383
  BufHandle a("A", {3, 4, 5}, kInt);
3384
  VarHandle x("x", kInt);
3385
  StmtPtr stmt = Block::make(
3386
      {Store::make(a, {0, 1, 2}, 0),
3387
       For::make(
3388
           x,
3389
           0,
3390
           10,
3391
           Block::make({Store::make(
3392
               a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});
3393

3394
  /*
3395
   * A[0, 1, 2] = 0;
3396
   * for (int x = 0; x < 10; x++) {
3397
   *   A[0, 1, 2] = (A[0, 1, 2]) + x;
3398
   * }
3399
   */
3400

3401
  stmt = registerize(stmt);
3402

3403
  /*
3404
   * int A_1 = 0;
3405
   * for (int x = 0; x < 10; x++) {
3406
   *   A_1 = x + A_1;
3407
   * }
3408
   * A[0, 1, 2] = A_1;
3409
   */
3410

3411
  std::ostringstream oss;
3412
  oss << *stmt;
3413

3414
  const std::string& verification_pattern =
3415
      R"IR(
3416
# CHECK: int A_1 = 0;
3417
# CHECK: for (int x = 0; x < 10; x++)
3418
# CHECK-NOT: A[
3419
# CHECK:   A_1 =
3420
# CHECK: A[0, 1, 2] = A_1;)IR";
3421

3422
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3423
}
3424

3425
// Wont registerize if only some dims match, but will still registerize distinct
3426
// elements.
3427
TEST(Registerizer, RegisterizerMultiDimPartial) {
3428
  BufHandle a("A", {3, 4, 5}, kInt);
3429
  VarHandle x("x", kInt);
3430
  StmtPtr stmt = Block::make(
3431
      {Store::make(a, {0, 1, 2}, 0),
3432
       For::make(
3433
           x,
3434
           0,
3435
           10,
3436
           Block::make({Store::make(
3437
               a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});
3438

3439
  /*
3440
   * A[0, 1, 2] = 0;
3441
   * for (int x = 0; x < 10; x++) {
3442
   *   A[0, 2, 2] = (A[0, 1, 4]) + x;
3443
   * }
3444
   */
3445

3446
  stmt = registerize(stmt);
3447

3448
  /*
3449
   * A[0, 1, 2] = 0;
3450
   * int A_1 = A[0, 1, 4];
3451
   * int A_2 = A[0, 2, 2];
3452
   * for (int x = 0; x < 10; x++) {
3453
   *   A_2 = A_1 + x;
3454
   * }
3455
   * A[0, 2, 2] = A_2;
3456
   */
3457

3458
  std::ostringstream oss;
3459
  oss << *stmt;
3460

3461
  const std::string& verification_pattern =
3462
      R"IR(
3463
# CHECK: A[0, 1, 2] = 0;
3464
# CHECK: int A_1 = A[0, 1, 4];
3465
# CHECK: int A_2 = A[0, 2, 2];
3466
# CHECK: for (
3467
# CHECK:   A_2 = A_1 + x;
3468
# CHECK: A[0, 2, 2] = A_2;)IR";
3469

3470
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3471
}
3472

3473
// If they could overlap across all dimensions we cannot registerize.
3474
TEST(Registerizer, RegisterizerMultiDimOverlap) {
3475
  BufHandle a("A", {3, 4, 5}, kInt);
3476
  VarHandle x("x", kInt);
3477
  VarHandle y("y", kInt);
3478
  StmtPtr stmt = Block::make(
3479
      {Store::make(a, {0, 1, 2}, 0),
3480
       For::make(
3481
           x,
3482
           0,
3483
           10,
3484
           Block::make({Store::make(
3485
               a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});
3486
  stmt = IRSimplifier::simplify(stmt);
3487

3488
  /*
3489
   * A[0, 1, 2] = 0;
3490
   * for (int x = 0; x < 10; x++) {
3491
   *   A[0, x, 2] = (A[y, 2, 2]) + x;
3492
   * }
3493
   */
3494

3495
  std::ostringstream before;
3496
  before << *stmt;
3497

3498
  // No change.
3499
  stmt = registerize(stmt);
3500

3501
  std::ostringstream after;
3502
  after << *stmt;
3503

3504
  ASSERT_EQ(before.str(), after.str());
3505
}
3506

3507
// But, if one dimension is known to be distinct they do not overlap.
3508
TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {
3509
  BufHandle a("A", {3, 4, 5}, kInt);
3510
  VarHandle x("x", kInt);
3511
  VarHandle y("y", kInt);
3512
  StmtPtr stmt = Block::make(
3513
      {Store::make(a, {0, 1, 2}, 0),
3514
       For::make(
3515
           x,
3516
           0,
3517
           10,
3518
           Block::make({Store::make(
3519
               a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});
3520

3521
  /*
3522
   * A[0, 1, 2] = 0;                          <---- 2nd dim overlaps with store.
3523
   * for (int x = 0; x < 10; x++) {
3524
   *   A[0, x, 2] = (A[y, 2, 4]) + x;           <---- 3rd dim has constant diff.
3525
   * }
3526
   */
3527

3528
  stmt = registerize(stmt);
3529

3530
  /*
3531
   * A[0, 1, 2] = 0;
3532
   * int A_1 = A[y, 2, 4];
3533
   * for (int x = 0; x < 10; x++) {
3534
   *   A[0, x, 2] = A_1 + x;
3535
   * }
3536
   */
3537

3538
  std::ostringstream oss;
3539
  oss << *stmt;
3540

3541
  const std::string& verification_pattern =
3542
      R"IR(
3543
# CHECK: A[0, 1, 2] = 0;
3544
# CHECK: int A_1 = A[y, 2, 4];
3545
# CHECK: for (
3546
# CHECK:   A[0, x, 2] = A_1 + x;
3547
# CHECK: })IR";
3548

3549
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3550
}
3551

3552
// A 3D reduction with different input dimensionality.
3553
TEST(Registerizer, RegisterizerMultiDim3DReduction1) {
3554
  BufHandle a("A", {10}, kInt);
3555
  BufHandle b("B", {10, 10}, kInt);
3556
  BufHandle c("C", {10, 10, 10}, kInt);
3557
  VarHandle x("x", kInt);
3558
  VarHandle y("y", kInt);
3559
  VarHandle z("z", kInt);
3560
  StmtPtr stmt = For::make(
3561
      x,
3562
      0,
3563
      10,
3564
      For::make(
3565
          y,
3566
          0,
3567
          10,
3568
          For::make(
3569
              z,
3570
              0,
3571
              10,
3572
              Store::make(
3573
                  c,
3574
                  {x, y, z},
3575
                  Add::make(
3576
                      Load::make(c, {x, y, z}),
3577
                      Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));
3578

3579
  /*
3580
   * for (int x = 0; x < 10; x++) {
3581
   *   for (int y = 0; y < 10; y++) {
3582
   *     for (int z = 0; z < 10; z++) {
3583
   *       C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
3584
   *     }
3585
   *   }
3586
   * }
3587
   */
3588

3589
  // We can registerize the A and B access since they can be hoisted before
3590
  // hitting a dependent loop var.
3591

3592
  stmt = registerize(stmt);
3593

3594
  /*
3595
   * for (int x = 0; x < 10; x++) {
3596
   *   int A_1 = A[x];
3597
   *   for (int y = 0; y < 10; y++) {
3598
   *     int B_1 = B[x, y];
3599
   *     for (int z = 0; z < 10; z++) {
3600
   *       C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3601
   *     }
3602
   *   }
3603
   * }
3604
   */
3605

3606
  std::ostringstream oss;
3607
  oss << *stmt;
3608

3609
  const std::string& verification_pattern =
3610
      R"IR(
3611
# CHECK: for (int x
3612
# CHECK:   int A_1 = A[x];
3613
# CHECK:   for (int y
3614
# CHECK:     int B_1 = B[x, y];
3615
# CHECK:       for (int z
3616
# CHECK:         C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3617
# CHECK: })IR";
3618

3619
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3620
}
3621

3622
// A 3D reduction with the same smaller dimensionality using different loop
3623
// vars.
3624
TEST(Registerizer, RegisterizerMultiDim3DReduction2) {
3625
  BufHandle a("A", {10}, kInt);
3626
  BufHandle b("B", {10}, kInt);
3627
  BufHandle c("C", {10}, kInt);
3628
  VarHandle x("x", kInt);
3629
  VarHandle y("y", kInt);
3630
  VarHandle z("z", kInt);
3631
  StmtPtr stmt = For::make(
3632
      x,
3633
      0,
3634
      10,
3635
      // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
3636
      For::make(
3637
          y,
3638
          0,
3639
          10,
3640
          For::make(
3641
              z,
3642
              0,
3643
              10,
3644
              Store::make(
3645
                  c,
3646
                  {x},
3647
                  Add::make(
3648
                      Load::make(c, {x}),
3649
                      Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));
3650

3651
  /*
3652
   * for (int x = 0; x < 10; x++) {
3653
   *   for (int y = 0; y < 10; y++) {
3654
   *     for (int z = 0; z < 10; z++) {
3655
   *       C[x] = (C[x]) + (B[y]) * (A[x]);
3656
   *     }
3657
   *   }
3658
   * }
3659
   */
3660

3661
  // We can registerize all accesses, the A and C access can be hoisted to the
3662
  // outer loop since they depend only on it's loop var while the B can only be
3663
  // raised to the loop of y.
3664

3665
  stmt = registerize(stmt);
3666

3667
  /*
3668
   * for (int x = 0; x < 10; x++) {
3669
   *   int A_1 = A[x];
3670
   *   int C_1 = C[x];
3671
   *   for (int y = 0; y < 10; y++) {
3672
   *     int B_1 = B[y];
3673
   *     for (int z = 0; z < 10; z++) {
3674
   *       C_1 = A_1 * B_1 + C_1;
3675
   *     }
3676
   *   }
3677
   *   C[x] = C_1;
3678
   * }
3679
   */
3680

3681
  std::ostringstream oss;
3682
  oss << *stmt;
3683

3684
  const std::string& verification_pattern =
3685
      R"IR(
3686
# CHECK: for (int x
3687
# CHECK:   int A_1 = A[x];
3688
# CHECK:   int C_1 = C[x];
3689
# CHECK:   for (int y
3690
# CHECK:     int B_1 = B[y];
3691
# CHECK:       for (int z
3692
# CHECK:         C_1 = A_1 * B_1 + C_1;
3693
# CHECK:       }
3694
# CHECK:     }
3695
# CHECK:   C[x] = C_1;
3696
# CHECK: })IR";
3697

3698
  torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
3699
}
3700

3701
} // namespace jit
3702
} // namespace torch
3703

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.