pytorch
3702 строки · 87.4 Кб
1#include <gtest/gtest.h>2#include "test/cpp/tensorexpr/test_base.h"3
4#include "test/cpp/tensorexpr/test_utils.h"5#include "torch/csrc/jit/tensorexpr/ir_simplifier.h"6#include "torch/csrc/jit/tensorexpr/registerizer.h"7
8#include <iostream>9
10namespace torch {11namespace jit {12using namespace torch::jit::tensorexpr;13
14// Can replace a simple scalar access with a local variable.
15TEST(Registerizer, RegisterizerSimple) {16BufHandle a("A", {1}, kInt);17VarHandle x("x", kInt);18StmtPtr stmt = Block::make(19{Store::make(a, {0}, 0),20For::make(21x,220,2310,24Block::make(25{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});26
27/*28* A[0] = 0;
29* for (int x = 0; x < 10; x++) {
30* A[0] = (A[0]) + x;
31* }
32*/
33
34stmt = registerize(stmt);35
36/*37* int A_1 = 0;
38* for (int x = 0; x < 10; x++) {
39* A_1 = x + A_1;
40* }
41* A[0] = A_1;
42*/
43
44std::ostringstream oss;45oss << *stmt;46
47const std::string& verification_pattern =48R"IR(49# CHECK: int A_1 = 0;
50# CHECK: for (int x = 0; x < 10; x++)
51# CHECK-NOT: A[
52# CHECK: A_1 =
53# CHECK: A[0] = A_1;)IR";54
55torch::jit::testing::FileCheck().run(verification_pattern, oss.str());56}
57
58// Won't do replacement of a loop access.
59TEST(Registerizer, RegisterizerLoop) {60BufHandle a("A", {10}, kInt);61VarHandle x("x", kInt);62StmtPtr stmt = Block::make(63{Store::make(a, {0}, 0),64For::make(65x,660,6710,68Block::make(69{Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});70
71/*72* A[0] = 0;
73* for (int x = 0; x < 10; x++) {
74* A[x] = (A[x]) + x;
75* }
76*/
77
78// No change.79stmt = registerize(stmt);80
81/*82* A[0] = 0;
83* for (int x = 0; x < 10; x++) {
84* A[x] = (A[x]) + x;
85* }
86*/
87
88std::ostringstream oss;89oss << *stmt;90
91const std::string& verification_pattern =92R"IR(93# CHECK-NOT: int
94# CHECK: A[0] = 0;
95# CHECK: for (int x = 0; x < 10; x++)
96# CHECK-NOT: A_
97# CHECK: A[x] =
98# CHECK-NOT: A[0] = A_1;)IR";99
100torch::jit::testing::FileCheck().run(verification_pattern, oss.str());101}
102
103// Won't replace even if the load is a fixed scalar, since the store could
104// invalidate it.
105TEST(Registerizer, RegisterizerLoopFixedLoad) {106BufHandle a("A", {1}, kInt);107VarHandle x("x", kInt);108StmtPtr stmt = Block::make(109{Store::make(a, {0}, 0),110For::make(111x,1120,11310,114Block::make(115{Store::make(a, {x}, Add::make(Load::make(a, {0}), x))}))});116
117/*118* A[0] = 0;
119* for (int x = 0; x < 10; x++) {
120* A[x] = (A[0]) + x;
121* }
122*/
123
124// No change.125stmt = registerize(stmt);126
127/*128* A[0] = 0;
129* for (int x = 0; x < 10; x++) {
130* A[x] = (A[0]) + x;
131* }
132*/
133
134std::ostringstream oss;135oss << *stmt;136
137const std::string& verification_pattern =138R"IR(139# CHECK-NOT: int
140# CHECK: A[0] = 0;
141# CHECK: for (int x = 0; x < 10; x++)
142# CHECK-NOT: A_
143# CHECK: A[x] =
144# CHECK-NOT: A[0] = A_1;)IR";145
146torch::jit::testing::FileCheck().run(verification_pattern, oss.str());147}
148
149// We can registerize accesses that occur entirely within inner scopes, even if
150// they depend on the loop var.
151TEST(Registerizer, RegisterizerLoopInternal) {152BufHandle a("A", {1}, kInt);153VarHandle x("x", kInt);154StmtPtr stmt = Block::make({For::make(155x,1560,15710,158Block::make(159{Store::make(a, {x}, Add::make(Load::make(a, {x}), x)),160Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});161
162/*163* for (int x = 0; x < 10; x++) {
164* A[x] = (A[x]) + x;
165* A[x] = (A[x]) + x;
166* }
167*/
168
169stmt = registerize(stmt);170
171// TODO: the order of terms in addition changes and in general depends on172// some hash value. This results in unpredictable swaps of the operands from173// random changes, which is not great. Ideally, we should ensure some174// specific order (ideally, the original one).175/*176* for (int x = 0; x < 10; x++) {
177* int A_1 = A[x];
178* A_1 = x + A_1;
179* A_1 = x + A_1;
180* A[x] = A_1;
181* }
182*/
183
184std::ostringstream oss;185oss << *stmt;186
187const std::string& verification_pattern =188R"IR(189# CHECK: for (int x = 0; x < 10; x++)
190# CHECK: int A_1 = A[x];
191# CHECK: A_1 = A_1 + x;
192# CHECK: A_1 = A_1 + x;
193# CHECK: A[x] = A_1;
194# CHECK: })IR";195
196torch::jit::testing::FileCheck().run(verification_pattern, oss.str());197}
198
199// An access can be overlapped by another read in the same Expr. In this case
200// B[z] and B[y] overlap and prevent registerization of both accesses.
201TEST(Registerizer, RegisterizerLoopInternalLoadOverlap) {202BufHandle a("A", {10}, kInt);203BufHandle b("B", {10}, kInt);204VarHandle x("x", kInt);205VarHandle y("y", kInt);206VarHandle z("z", kInt);207StmtPtr stmt = Block::make({For::make(208x,2090,21010,211Store::make(a, {x}, Add::make(Load::make(b, {y}), Load::make(b, {z}))))});212stmt = IRSimplifier::simplify(stmt);213
214/*215* for (int x = 0; x < 10; x++) {
216* A[x] = (B[y]) + (B[z]);
217* }
218*/
219
220std::ostringstream before;221before << *stmt;222
223// No change.224stmt = registerize(stmt);225
226std::ostringstream after;227after << *stmt;228
229ASSERT_EQ(before.str(), after.str());230}
231
232TEST(Registerizer, RegisterizerLoopInternalRepeated) {233BufHandle a("A", {1}, kInt);234VarHandle x("x", kInt);235StmtPtr stmt = Block::make(236{For::make(237x,2380,23910,240Block::make(241{Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),242Store::make(a, {0}, Add::make(Load::make(a, {1}), x))})),243For::make(244x,2450,24610,247Block::make(248{Store::make(a, {0}, Add::make(Load::make(a, {1}), x)),249Store::make(a, {0}, Add::make(Load::make(a, {1}), x))}))250
251});252
253/*254* for (int x = 0; x < 10; x++) {
255* A[0] = x + (A[1]);
256* A[0] = x + (A[1]);
257* }
258* for (int x = 0; x < 10; x++) {
259* A[0] = x + (A[1]);
260* A[0] = x + (A[1]);
261* }
262*/
263
264stmt = registerize(stmt);265
266/*267* int A_1 = A[1];
268* int A_2 = A[0];
269* for (int x = 0; x < 10; x++) {
270* A_2 = A_1 + x;
271* A_2 = A_1 + x;
272* }
273* for (int x = 0; x < 10; x++) {
274* A_2 = A_1 + x;
275* A_2 = A_1 + x;
276* }
277* A[0] = A_2;
278*/
279
280std::ostringstream oss;281oss << *stmt;282
283const std::string& verification_pattern =284R"IR(285# CHECK: int A_1 = A[1];
286# CHECK: int A_2 = A[0];
287# CHECK: for (int x = 0; x < 10; x++)
288# CHECK: A_2 = A_1 + x;
289# CHECK: A_2 = A_1 + x;
290# CHECK: }
291# CHECK: for (int x = 0; x < 10; x++)
292# CHECK: A_2 = A_1 + x;
293# CHECK: A_2 = A_1 + x;
294# CHECK: }
295# CHECK-NOT: A[1]
296# CHECK: A[0] = A_2;
297# CHECK-NOT: A[1]
298# CHECK: })IR";299
300torch::jit::testing::FileCheck().run(verification_pattern, oss.str());301}
302
303TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapLoopVar) {304BufHandle a("A", {1}, kInt);305VarHandle x("x", kInt);306StmtPtr stmt = Block::make(307{For::make(308x,3090,31010,311Block::make(312{Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),313Store::make(a, {0}, Add::make(Load::make(a, {x}), x))})),314For::make(315x,3160,31710,318Block::make(319{Store::make(a, {0}, Add::make(Load::make(a, {x}), x)),320Store::make(a, {0}, Add::make(Load::make(a, {x}), x))}))321
322});323stmt = IRSimplifier::simplify(stmt);324
325/*326* for (int x = 0; x < 10; x++) {
327* A[0] = (A[x]) + x;
328* A[0] = (A[x]) + x;
329* }
330* for (int x = 0; x < 10; x++) {
331* A[0] = (A[x]) + x;
332* A[0] = (A[x]) + x;
333* }
334*/
335
336std::ostringstream before;337before << *stmt;338
339// No change.340stmt = registerize(stmt);341
342std::ostringstream after;343after << *stmt;344
345ASSERT_EQ(before.str(), after.str());346}
347
348TEST(Registerizer, RegisterizerLoopInternalRepeatedOverlapOther) {349BufHandle a("A", {1}, kInt);350VarHandle x("x", kInt);351VarHandle y("y", kInt);352StmtPtr stmt = IRSimplifier::simplify(Block::make(353{For::make(354x,3550,35610,357Block::make(358{Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),359Store::make(a, {0}, Add::make(x, Load::make(a, {y})))})),360For::make(361x,3620,36310,364Block::make(365{Store::make(a, {0}, Add::make(x, Load::make(a, {y}))),366Store::make(a, {0}, Add::make(x, Load::make(a, {y})))}))367
368}));369
370/*371* for (int x = 0; x < 10; x++) {
372* A[0] = (A[x]) + x;
373* A[0] = (A[x]) + x;
374* }
375* for (int x = 0; x < 10; x++) {
376* A[0] = (A[x]) + x;
377* A[0] = (A[x]) + x;
378* }
379*/
380
381std::ostringstream before;382before << *stmt;383
384// No change.385stmt = registerize(stmt);386
387std::ostringstream after;388after << *stmt;389
390ASSERT_EQ(before.str(), after.str());391}
392
393// Will registerize multiple accesses of different items of the same buffer.
394TEST(Registerizer, RegisterizerMultiVar) {395BufHandle a("A", {2}, kInt);396VarHandle x("x", kInt);397StmtPtr stmt = Block::make({398Store::make(a, {0}, 0),399Store::make(a, {1}, 0),400For::make(401x,4020,40310,404Block::make(405{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),406Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),407});408
409/*410* A[0] = 0;
411* A[1] = 0;
412* for (int x = 0; x < 10; x++) {
413* A[0] = (A[0]) + x;
414* A[1] = (A[1]) - x;
415* }
416*/
417
418stmt = registerize(stmt);419
420/*421* int A_1 = 0;
422* int A_2 = 0;
423* for (int x = 0; x < 10; x++) {
424* A_2 = x + A_2;
425* A_1 = A_1 - x;
426* }
427* A[1] = A_2;
428* A[0] = A_1;
429*/
430
431std::ostringstream oss;432oss << *stmt;433
434const std::string& verification_pattern =435R"IR(436# CHECK: int A_1 = 0;
437# CHECK: int A_2 = 0;
438# CHECK: for (int x = 0; x < 10; x++)
439# CHECK-NOT: A[
440# CHECK: A_1 =
441# CHECK: A_2 =
442# CHECK: A[1] = A_2
443# CHECK: A[0] = A_1;)IR";444
445torch::jit::testing::FileCheck().run(verification_pattern, oss.str());446}
447
448// Will registerize the valid accesses while skipping invalid replacements.
449TEST(Registerizer, RegisterizerVariableLoad) {450BufHandle a("A", {1}, kInt);451BufHandle b("B", {10}, kInt);452VarHandle x("x", kInt);453VarHandle x2("x", kInt);454StmtPtr stmt = Block::make(455{Store::make(a, {0}, 0),456For::make(x, 0, 10, Store::make(b, {x}, x)),457For::make(458x2,4590,46010,461Block::make({Store::make(462a, {0}, Add::make(Load::make(a, {0}), Load::make(b, {x2})))}))});463
464/*465* A[0] = 0;
466* for (int x = 0; x < 10; x++) {
467* B[x] = x;
468* }
469* for (int x_1 = 0; x_1 < 10; x_1++) {
470* A[0] = (A[0]) + (B[x_1]);
471* }
472*/
473
474stmt = registerize(stmt);475
476/*477* int A_1 = 0;
478* for (int x = 0; x < 10; x++) {
479* B[x] = x;
480* }
481* for (int x_1 = 0; x_1 < 10; x_1++) {
482* A_1 = A_1 + (B[x_1]);
483* }
484* A[0] = A_1;
485*/
486
487std::ostringstream oss;488oss << *stmt;489
490const std::string& verification_pattern =491R"IR(492# CHECK: int A_1 = 0;
493# CHECK: for (int x = 0; x < 10; x++)
494# CHECK: B[x] = x
495# CHECK: for (int x_1 = 0; x_1 < 10; x_1++)
496# CHECK-NOT: A[
497# CHECK: A_1 =
498# CHECK: A[0] = A_1;)IR";499
500torch::jit::testing::FileCheck().run(verification_pattern, oss.str());501}
502
503// Can registerize variable accesses so long as the variable does not change.
504TEST(Registerizer, RegisterizerSymbolicIndices) {505VarHandle i("i", kInt);506VarHandle N("N", kInt);507BufHandle a("A", {N}, kInt);508VarHandle x("x", kInt);509StmtPtr stmt = Block::make(510{Store::make(a, {i}, 0),511For::make(512x,5130,51410,515Block::make(516{Store::make(a, {i}, Add::make(Load::make(a, {i}), x))}))});517
518/*519* A[i] = 0;
520* for (int x = 0; x < 10; x++) {
521* A[i] = (A[i]) + x;
522* }
523*/
524
525stmt = registerize(stmt);526
527/*528* int A_1 = 0;
529* for (int x = 0; x < 10; x++) {
530* A_1 = x + A_1;
531* }
532* A[i] = A_1;
533*/
534
535std::ostringstream oss;536oss << *stmt;537
538const std::string& verification_pattern =539R"IR(540# CHECK: int A_1 = 0;
541# CHECK: for (int x = 0; x < 10; x++)
542# CHECK-NOT: A[
543# CHECK: A_1 =
544# CHECK: A[i] = A_1;)IR";545
546torch::jit::testing::FileCheck().run(verification_pattern, oss.str());547}
548
549// Can registerize accesses dependent on multiple loop vars.
550TEST(Registerizer, RegisterizerMultiLoop) {551BufHandle a("A", {1}, kInt);552VarHandle x("x", kInt);553VarHandle y("y", kInt);554StmtPtr stmt = Block::make(555{Store::make(a, {0}, 0),556For::make(557x,5580,55910,560For::make(561y,5620,56310,564Block::make({Store::make(565a,566{0},567Mul::make(Add::make(Load::make(a, {0}), x), y))})))});568
569/*570* A[0] = 0;
571* for (int x = 0; x < 10; x++) {
572* for (int y = 0; y < 10; y++) {
573* A[0] = x * y + (A[0]) * y;
574* }
575* }
576*/
577
578stmt = registerize(stmt);579
580/*581* int A_1 = 0;
582* for (int x = 0; x < 10; x++) {
583* for (int y = 0; y < 10; y++) {
584* A_1 = x * y + y * A_1;
585* }
586* }
587* A[0] = A_1;
588*/
589
590std::ostringstream oss;591oss << *stmt;592
593const std::string& verification_pattern =594R"IR(595# CHECK: int A_1 = 0;
596# CHECK: for (int x = 0; x < 10; x++)
597# CHECK: for (int y = 0; y < 10; y++)
598# CHECK-NOT: A[
599# CHECK: A_1 =
600# CHECK: A[0] = A_1;)IR";601
602torch::jit::testing::FileCheck().run(verification_pattern, oss.str());603}
604
605// Can registerize correctly if scalars already exist in the program.
606TEST(Registerizer, RegisterizerRepeated) {607BufHandle a("A", {2}, kInt);608VarHandle x("x", kInt);609StmtPtr stmt = Block::make({610Store::make(a, {0}, 0),611Store::make(a, {1}, 0),612For::make(613x,6140,61510,616Block::make(617{Store::make(a, {0}, Add::make(Load::make(a, {0}), x)),618Store::make(a, {1}, Sub::make(Load::make(a, {1}), x))})),619});620
621// Registerize manually to make sure we only replace a single target.622{623registerizer::RegisterizerAnalysis analysis;624stmt->accept(&analysis);625auto candidates = analysis.getCandidates();626ASSERT_EQ(candidates.size(), 2);627
628candidates.pop_back();629registerizer::RegisterizerReplacer replacer(candidates);630stmt = stmt->accept_mutator(&replacer);631}632
633// Re-analyze and replace the second target.634{635registerizer::RegisterizerAnalysis analysis;636stmt->accept(&analysis);637auto candidates = analysis.getCandidates();638ASSERT_EQ(candidates.size(), 1);639
640registerizer::RegisterizerReplacer replacer(candidates);641stmt = stmt->accept_mutator(&replacer);642}643
644std::ostringstream oss;645oss << *stmt;646
647const std::string& verification_pattern =648R"IR(649# CHECK: int A_1 = 0;
650# CHECK: int A_1_1 = 0;
651# CHECK: for (int x = 0; x < 10; x++)
652# CHECK-NOT: A[
653# CHECK: A_1 =
654# CHECK: A_1_1 =
655# CHECK: A[1] = A_1_1;
656# CHECK: A[0] = A_1;)IR";657
658torch::jit::testing::FileCheck().run(verification_pattern, oss.str());659}
660
661// Can registerize the load of A.
662TEST(Registerizer, RegisterizerNoLoads) {663BufHandle a("A", {1}, kInt);664VarHandle x("x", kInt);665StmtPtr stmt = Block::make(666{Store::make(a, {0}, 0),667For::make(668x, 0, 10, Block::make({Store::make(a, {0}, Add::make(x, 1))}))});669
670/*671* A[0] = 0;
672* for (int x = 0; x < 10; x++) {
673* A[0] = x + 1;
674* }
675*/
676
677stmt = registerize(stmt);678
679/*680* int A_1 = 0;
681* for (int x = 0; x < 10; x++) {
682* A_1 = x + 1;
683* }
684* A[0] = A_1;
685*/
686
687std::ostringstream oss;688oss << *stmt;689
690const std::string& verification_pattern =691R"IR(692# CHECK: int A_1 = 0;
693# CHECK: for (int x = 0; x < 10; x++)
694# CHECK-NOT: A[
695# CHECK: A_1 =
696# CHECK: A[0] = A_1;)IR";697
698torch::jit::testing::FileCheck().run(verification_pattern, oss.str());699}
700
701// Can registerize the load of A but not the store of B.
702TEST(Registerizer, RegisterizerNoRepeatedStores) {703BufHandle a("A", {1}, kInt);704BufHandle b("B", {10}, kInt);705VarHandle x("x", kInt);706StmtPtr stmt = Block::make(707{Store::make(a, {0}, 0),708For::make(709x,7100,71110,712Block::make(713{Store::make(b, {x}, Add::make(Load::make(a, {0}), x))}))});714
715/*716* A[0] = 0;
717* for (int x = 0; x < 10; x++) {
718* B[x] = (A[0]) + x;
719* }
720*/
721
722stmt = registerize(stmt);723
724// TODO: its unnecessary to reorder the initializer of A[0], but it's not725// actually worse so lets not worry for now.726
727/*728* int A_1 = 0;
729* for (int x = 0; x < 10; x++) {
730* B[x] = x + A_1;
731* }
732* A[0] = A_1;
733*/
734
735std::ostringstream oss;736oss << *stmt;737
738const std::string& verification_pattern =739R"IR(740# CHECK: int A_1 = 0;
741# CHECK: for (int x = 0; x < 10; x++)
742# CHECK-NOT: A_
743# CHECK: B[x] =
744# CHECK: A[0] = A_1;)IR";745
746torch::jit::testing::FileCheck().run(verification_pattern, oss.str());747}
748
749// Won't registerize if there are multiple accesses which may overlap.
750TEST(Registerizer, RegisterizerMultiVarOverlap) {751BufHandle a("A", {2}, kInt);752VarHandle x("x", kInt);753StmtPtr stmt = Block::make({754Store::make(a, {0}, 0),755Store::make(a, {1}, 0),756For::make(757x,7580,75910,760Block::make(761{Store::make(a, {x}, Add::make(Load::make(a, {0}), x)),762Store::make(a, {x + 1}, Sub::make(Load::make(a, {1}), x))})),763});764stmt = IRSimplifier::simplify(stmt);765
766std::ostringstream before;767before << *stmt;768
769// No change.770stmt = registerize(stmt);771
772std::ostringstream after;773after << *stmt;774
775ASSERT_EQ(before.str(), after.str());776}
777
778TEST(Registerizer, RegisterizerAllocs) {779BufHandle a("A", {2}, kInt);780BufHandle c("C", {1}, kInt);781VarHandle x("x", kInt);782
783BufHandle b("B", {Load::make(c, {0})}, kInt);784
785StmtPtr stmt = Block::make(786{Allocate::make(b),787Store::make(a, {0}, Load::make(c, {0})),788Store::make(b, {0}, 0),789For::make(790x,7910,79210,793Block::make(794{Store::make(b, {0}, Add::make(Load::make(b, {0}), x)),795Store::make(a, {0}, Load::make(c, {0}))})),796Free::make(b)});797
798/*799* Allocate(B, int, {C[0]});
800* A[0] = C[0];
801* B[0] = 0;
802* for (int x = 0; x < 10; x++) {
803* B[0] = (B[0]) + x;
804* A[0] = C[0];
805* }
806* Free(B);
807*/
808
809stmt = registerize(stmt);810
811/*812* int C_1 = C[0];
813* Allocate(B, int, {C_});
814* int A_1 = C_1;
815* int B_1 = 0;
816* for (int x = 0; x < 10; x++) {
817* B_1 = B_1 + x;
818* A_1 = C_1;
819* }
820* B[0] = B_1;
821* A[0] = A_1;
822* Free(B);
823*/
824
825std::ostringstream oss;826oss << *stmt;827
828const std::string& verification_pattern =829R"IR(830# CHECK: int C_1 = C[0];
831# CHECK: Allocate(B
832# CHECK: int A_1 = C_1;
833# CHECK: int B_1 = 0;
834# CHECK: for (int x = 0; x < 10; x++)
835# CHECK: B_1 =
836# CHECK: A_1 = C_
837# CHECK: B[0] = B_1;
838# CHECK: A[0] = A_1;
839# CHECK: Free(B)IR";840
841torch::jit::testing::FileCheck().run(verification_pattern, oss.str());842}
843
844TEST(Registerizer, RegisterizerNoInitializer) {845BufHandle a("A", {1}, kInt);846VarHandle x("x", kInt);847StmtPtr stmt = Block::make({For::make(848x,8490,85010,851Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});852
853/*854* for (int x = 0; x < 10; x++) {
855* A[0] = (A[0]) + x;
856* }
857*/
858
859stmt = registerize(stmt);860
861/*862* int A_1 = A[0];
863* for (int x = 0; x < 10; x++) {
864* A_1 = x + A_1;
865* }
866* A[0] = A_1;
867*/
868
869std::ostringstream oss;870oss << *stmt;871
872const std::string& verification_pattern =873R"IR(874# CHECK: int A_1 = A[0];
875# CHECK: for (int x = 0; x < 10; x++)
876# CHECK-NOT: A[
877# CHECK: A_1 =
878# CHECK: A[0] = A_1;)IR";879
880torch::jit::testing::FileCheck().run(verification_pattern, oss.str());881}
882
883TEST(Registerizer, RegisterizerNoInitializerLoopVar) {884BufHandle a("A", {1}, kInt);885VarHandle x("x", kInt);886StmtPtr stmt = Block::make({For::make(887x,8880,88910,890Block::make({Store::make(a, {x}, Add::make(Load::make(a, {x}), x))}))});891stmt = IRSimplifier::simplify(stmt);892
893/*894* for (int x = 0; x < 10; x++) {
895* A[x] = (A[x]) + x;
896* }
897*/
898
899std::ostringstream before;900before << *stmt;901
902// No change.903stmt = registerize(stmt);904
905std::ostringstream after;906after << *stmt;907
908ASSERT_EQ(before.str(), after.str());909}
910
911TEST(Registerizer, RegisterizerLoadThenStore) {912BufHandle a("A", {1}, kInt);913BufHandle b("B", {1}, kInt);914VarHandle x("x", kInt);915StmtPtr stmt = Block::make({For::make(916x,9170,91810,919Block::make(920{Store::make(b, {0}, Add::make(Load::make(a, {0}), x)),921Store::make(a, {0}, Load::make(b, {0}))}))});922
923/*924* for (int x = 0; x < 10; x++) {
925* B[0] = (A[0]) + x;
926* A[0] = B[0];
927* }
928*/
929
930stmt = registerize(stmt);931
932/*933* int A_1 = A[0];
934* int B_1 = B[0];
935* for (int x = 0; x < 10; x++) {
936* B_1 = x + A_1;
937* A_1 = B_1;
938* }
939* B[0] = B_1;
940* A[0] = A_1;
941*/
942
943std::ostringstream oss;944oss << *stmt;945
946const std::string& verification_pattern =947R"IR(948# CHECK: int A_1 = A[0];
949# CHECK: int B_1 = B[0];
950# CHECK: for (int x = 0; x < 10; x++)
951# CHECK-NOT: B[
952# CHECK: B_1 =
953# CHECK-NOT: A[
954# CHECK: A_1 = B_
955# CHECK: B[0] = B_
956# CHECK: A[0] = A_1;)IR";957
958torch::jit::testing::FileCheck().run(verification_pattern, oss.str());959}
960
961TEST(Registerizer, RegisterizerParallelized) {962BufHandle a("A", {1}, kInt);963VarHandle x("x", kInt);964LoopOptions loopOpts;965loopOpts.set_gpu_block_index(0);966StmtPtr stmt = Block::make(967{Store::make(a, {0}, 0),968For::make(969x,9700,97110,972Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}),973loopOpts)});974
975/*976* A[0] = 0;
977* for (int x = 0; x < 10; x++) {
978* A[0] = (A[0]) + x;
979* }
980*/
981
982ASSERT_THROWS_WITH(983registerize(stmt),984"Registerization must occur after parallelism flattening");985}
986
987// Should be able to registerize this since the scalar would exist before the
988// branch.
989TEST(Registerizer, RegisterizerConditionAfter) {990BufHandle a("A", {5}, kInt);991BufHandle b("B", {5}, kInt);992BufHandle c("C", {5}, kInt);993VarHandle x("x", kInt);994
995StmtPtr stmt = Block::make(996{Store::make(a, {x}, Load::make(b, {x})),997Store::make(c, {x}, Load::make(a, {x})),998Cond::make(999CompareSelect::make(x, 5, CompareSelectOperation::kLT),1000Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1001nullptr)});1002
1003/*1004* A[x] = B[x];
1005* C[x] = A[x];
1006* if (x<5 ? 1 : 0) {
1007* A[x] = (A[x]) + 1;
1008* }
1009*/
1010
1011stmt = registerize(stmt);1012
1013/*1014* int A_1 = B[x];
1015* C[x] = A_1;
1016* if (x<5 ? 1 : 0) {
1017* A_1 = A_1 + 1;
1018* }
1019* A[x] = A_1;
1020*/
1021
1022std::ostringstream oss;1023oss << *stmt;1024
1025const std::string& verification_pattern =1026R"IR(1027# CHECK: int A_1 = B[x];
1028# CHECK: C[x] = A_1;
1029# CHECK: if (
1030# CHECK: A_1 = A_1 + 1;
1031# CHECK: A[x] = A_1;)IR";1032
1033torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1034}
1035
1036// Should be able to registerize this since the scalar exists in the same form
1037// after the branch and there is no overlap.
1038TEST(Registerizer, RegisterizerConditionBefore) {1039BufHandle a("A", {5}, kInt);1040BufHandle b("B", {5}, kInt);1041BufHandle c("C", {5}, kInt);1042VarHandle x("x", kInt);1043
1044StmtPtr stmt = Block::make(1045{Cond::make(1046CompareSelect::make(x, 5, CompareSelectOperation::kLT),1047Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1048nullptr),1049Store::make(a, {x}, Load::make(b, {x})),1050Store::make(c, {x}, Load::make(a, {x}))});1051
1052/*1053* if (x<5 ? 1 : 0) {
1054* A[x] = (A[x]) + 1;
1055* }
1056* A[x] = B[x];
1057* C[x] = A[x];
1058*/
1059
1060stmt = registerize(stmt);1061
1062/*1063* int A_ 1 = A[x];
1064* if (x<5 ? 1 : 0) {
1065* A_1 = A_1 + 1;
1066* }
1067* A_1 = B[x];
1068* C[x] = A_1;
1069* A[x] = A_1;
1070*/
1071
1072std::ostringstream oss;1073oss << *stmt;1074
1075const std::string& verification_pattern =1076R"IR(1077# CHECK: int A_1 = A[x];
1078# CHECK: if (
1079# CHECK: A_1 = A_1 + 1;
1080# CHECK: }
1081# CHECK: A_1 = B[x];
1082# CHECK: C[x] = A_1;
1083# CHECK: A[x] = A_1;)IR";1084
1085torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1086}
1087
1088// Should be able to registerize this as the combination of the two above rules.
1089TEST(Registerizer, RegisterizerConditionInside) {1090BufHandle a("A", {5}, kInt);1091BufHandle b("B", {5}, kInt);1092BufHandle c("C", {5}, kInt);1093VarHandle x("x", kInt);1094
1095StmtPtr stmt = Block::make(1096{Store::make(a, {x}, Load::make(b, {x})),1097Store::make(c, {x}, Load::make(a, {x})),1098Cond::make(1099CompareSelect::make(x, 5, CompareSelectOperation::kLT),1100Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1101nullptr),1102Store::make(b, {x}, Load::make(a, {x})),1103Store::make(a, {x}, Load::make(c, {x}))});1104
1105/*1106* A[x] = B[x];
1107* C[x] = A[x];
1108* if (x<5 ? 1 : 0) {
1109* A[x] = (A[x]) + 1;
1110* }
1111* B[x] = A[x];
1112* A[x] = C[x];
1113*/
1114
1115stmt = registerize(stmt);1116
1117/*1118* int A_1 = B[x];
1119* C[x] = A_1;
1120* if (x<5 ? 1 : 0) {
1121* A_1 = A_1 + 1;
1122* }
1123* B[x] = A_1;
1124* A_1 = C[x];
1125* A[x] = A_1;
1126*/
1127
1128std::ostringstream oss;1129oss << *stmt;1130
1131const std::string& verification_pattern =1132R"IR(1133# CHECK: int A_1 = B[x];
1134# CHECK: C[x] = A_1;
1135# CHECK: if (
1136# CHECK: A_1 = A_1 + 1;
1137# CHECK: }
1138# CHECK: B[x] = A_1;
1139# CHECK: A_1 = C[x];
1140# CHECK: A[x] = A_1;)IR";1141
1142torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1143}
1144
1145// An example where an access is cut by an overlapping access inside a
1146// condition, and both sides are large enough to be registerized but cannot be
1147// because there is no safe place to put the initializer or finalizer.
1148TEST(Registerizer, RegisterizerConditionInsideOverlap1) {1149BufHandle a("A", {5}, kInt);1150BufHandle b("B", {5}, kInt);1151BufHandle c("C", {5}, kInt);1152VarHandle x("x", kInt);1153VarHandle y("y", kInt);1154
1155StmtPtr stmt = Block::make(1156// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)1157{Store::make(a, {x}, Load::make(b, {x})),1158Store::make(c, {x}, Load::make(a, {x})),1159Cond::make(1160CompareSelect::make(x, 5, CompareSelectOperation::kLT),1161Block::make({1162Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1163Store::make(a, {0}, 3),1164Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1165}),1166nullptr),1167Store::make(b, {x}, Load::make(a, {x})),1168Store::make(a, {x}, Load::make(c, {x}))});1169
1170/*1171* A[x] = B[x];
1172* C[x] = A[x];
1173* if (x<5 ? 1 : 0) {
1174* A[x] = (A[x]) + 1;
1175* A[0] = 3;
1176* A[x] = (A[x]) + 1;
1177* }
1178* B[x] = A[x];
1179* A[x] = C[x];
1180*/
1181
1182// The A[0] store overlaps, A[x] cutting the region that can be registerized1183// into two groups.1184// Each group has 2 loads and 2 stores however, so we could registerize it,1185// but the first group would need to be finalized inside the condition block,1186// the second would need to be initialized inside the condition block. There's1187// no safe place to put these that's visible to the other uses in the group1188// and so neither registerization is possible.1189
1190std::ostringstream before;1191before << *stmt;1192
1193// No change.1194stmt = registerize(stmt);1195
1196std::ostringstream after;1197after << *stmt;1198
1199ASSERT_EQ(before.str(), after.str());1200}
1201
1202// Same as the above, but the access group before the condition (and after the
1203// condition) are large enough to be registerized without needing the access
1204// from the loop. Registerization occurs but does not include any accesses in
1205// the condition, and the first group must be finalized before the Cond, the
1206// second initialized after it.
1207TEST(Registerizer, RegisterizerConditionInsideOverlap2) {1208BufHandle a("A", {5}, kInt);1209BufHandle b("B", {5}, kInt);1210BufHandle c("C", {5}, kInt);1211VarHandle x("x", kInt);1212VarHandle y("y", kInt);1213
1214StmtPtr stmt = Block::make(1215// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)1216{Store::make(a, {x}, Load::make(b, {x})),1217Store::make(a, {x}, Load::make(b, {x + 1})),1218Store::make(c, {x}, Load::make(a, {x})),1219Cond::make(1220CompareSelect::make(x, 5, CompareSelectOperation::kLT),1221Block::make({1222Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1223Store::make(a, {0}, 3),1224Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1225}),1226nullptr),1227Store::make(b, {x}, Load::make(a, {x})),1228Store::make(b, {x + 1}, Load::make(a, {x})),1229Store::make(a, {x}, Load::make(c, {x}))});1230
1231/*1232* A[x] = B[x];
1233* A[x] = B[x + 1];
1234* C[x] = A[x];
1235* if (x<5 ? 1 : 0) {
1236* A[x] = (A[x]) + 1;
1237* A[0] = 3;
1238* A[x] = (A[x]) + 1;
1239* }
1240* B[x] = A[x];
1241* B[x + 1] = A[x];
1242* A[x] = C[x];
1243*/
1244
1245stmt = registerize(stmt);1246
1247/*1248* int A_1 = B[x]; // A_1 initializer
1249* A_1 = B[x + 1]; //
1250* C[x] = A_1; //
1251* A[x] = A_1; // A_1 finalizer
1252* if (x<5 ? 1 : 0) {
1253* A[x] = (A[x]) + 1;
1254* A[0] = 3;
1255* A[x] = (A[x]) + 1;
1256* }
1257* int A_2 = A[x]; // A_2 initialier
1258* B[x] = A_2; //
1259* B[x + 1] = A_2; //
1260* A_2 = C[x]; //
1261* A[x] = A_2; // A_2 finalizer
1262*/
1263
1264std::ostringstream oss;1265oss << *stmt;1266
1267const std::string& verification_pattern =1268R"IR(1269# CHECK: int A_1 = B[x];
1270# CHECK: A_1 = B[x + 1];
1271# CHECK: C[x] = A_1;
1272# CHECK: A[x] = A_1;
1273# CHECK: if (
1274# CHECK-NOT: A_1 = A_1 + 1;
1275# CHECK: A[x] = (A[x]
1276# CHECK: A[0] =
1277# CHECK: A[x] = (A[x]
1278# CHECK: }
1279# CHECK: int A_2 = A[x];
1280# CHECK: B[x] = A_2;
1281# CHECK: B[x + 1] = A_2;
1282# CHECK: A_2 = C[x];
1283# CHECK: A[x] = A_2;)IR";1284
1285torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1286}
1287
1288// When accesses are within conditional blocks they are not visible to the wider
1289// program, because we don't know if the branch would be taken and if it isn't
1290// the accesses in it don't need to be valid (think size checks on the index).
1291// In this case the accesses cannot be registerized.
1292TEST(Registerizer, RegisterizerConditionHidden) {1293BufHandle a("A", {5}, kInt);1294BufHandle b("B", {5}, kInt);1295BufHandle c("C", {5}, kInt);1296VarHandle x("x", kInt);1297
1298StmtPtr stmt = Block::make(1299{Cond::make(1300CompareSelect::make(x, 5, CompareSelectOperation::kLT),1301Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1302nullptr),1303Cond::make(1304CompareSelect::make(x, 5, CompareSelectOperation::kGT),1305Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1306nullptr)});1307
1308/*1309* if (x<5 ? 1 : 0) {
1310* A[x] = (A[x]) + 1;
1311* }
1312* if (x>5 ? 1 : 0) {
1313* A[x] = (A[x]) + 1;
1314* }
1315*/
1316
1317std::ostringstream before;1318before << *stmt;1319
1320// No change.1321stmt = registerize(stmt);1322
1323std::ostringstream after;1324after << *stmt;1325
1326ASSERT_EQ(before.str(), after.str());1327}
1328
1329// But... if the same access is found in a non conditional scope, that means
1330// that that access is valid in the higher scope (or at least if its not it's
1331// the user's fault). It "unhides" the conditional accesses, allowing
1332// registerization to occur.
1333TEST(Registerizer, RegisterizerConditionUnhidden) {1334BufHandle a("A", {5}, kInt);1335BufHandle b("B", {5}, kInt);1336BufHandle c("C", {5}, kInt);1337VarHandle x("x", kInt);1338
1339StmtPtr stmt = Block::make(1340{Cond::make(1341CompareSelect::make(x, 5, CompareSelectOperation::kLT),1342Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1343nullptr),1344Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1345Cond::make(1346CompareSelect::make(x, 5, CompareSelectOperation::kGT),1347Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1348nullptr)});1349
1350/*1351* if (x<5 ? 1 : 0) {
1352* A[x] = (A[x]) + 1;
1353* }
1354* A[x] = (A[x]) + 1; <-- this is doing the unhiding.
1355* if (x>5 ? 1 : 0) {
1356* A[x] = (A[x]) + 1;
1357* }
1358*/
1359
1360stmt = registerize(stmt);1361
1362/*1363* int A_1 = A[x];
1364* if (x<5 ? 1 : 0) {
1365* A_1 = A_1 + 1;
1366* }
1367* A_1 = A_1 + 1;
1368* if (x>5 ? 1 : 0) {
1369* A_1 = A_1 + 1;
1370* }
1371* A[x] = A_1;
1372*/
1373
1374std::ostringstream oss;1375oss << *stmt;1376
1377const std::string& verification_pattern =1378R"IR(1379# CHECK: int A_1 = A[x];
1380# CHECK: if (x<5
1381# CHECK: A_1 = A_1 + 1;
1382# CHECK: }
1383# CHECK: A_1 = A_1 + 1;
1384# CHECK: if (x>5
1385# CHECK: A_1 = A_1 + 1;
1386# CHECK: }
1387# CHECK: A[x] = A_1;)IR";1388
1389torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1390}
1391
1392// Can registerize a load that occurs in the condition of a Cond.
1393TEST(Registerizer, RegisterizerCondCondition) {1394BufHandle a("A", {5}, kInt);1395BufHandle b("B", {5}, kInt);1396BufHandle c("C", {5}, kInt);1397VarHandle x("x", kInt);1398
1399StmtPtr stmt = Block::make(1400{Store::make(a, {x}, Load::make(b, {x})),1401Store::make(c, {x}, Load::make(a, {x})),1402Cond::make(1403CompareSelect::make(1404Load::make(a, {x}), 5, CompareSelectOperation::kLT),1405Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),1406nullptr)});1407
1408/*1409* A[x] = B[x];
1410* C[x] = A[x];
1411* if ((A[x])<5 ? 1 : 0) {
1412* C[x] = (C[x]) + 1;
1413* }
1414*/
1415
1416stmt = registerize(stmt);1417
1418/*1419* int A_1 = B[x];
1420* int C_1 = A_1;
1421* if (A_1<5 ? 1 : 0) {
1422* C_1 = C_1 + 1;
1423* }
1424* C[x] = C_1;
1425*/
1426
1427std::ostringstream oss;1428oss << *stmt;1429
1430const std::string& verification_pattern =1431R"IR(1432# CHECK: int A_1 = B[x];
1433# CHECK: int C_1 = A_1;
1434# CHECK: if (A_1<5
1435# CHECK: C_1 = C_1 + 1;
1436# CHECK: C[x] = C_1;)IR";1437
1438torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1439}
1440
1441// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1442// and so we can registerize internal usages.
1443TEST(Registerizer, RegisterizerCondConditionUnhidden) {1444BufHandle a("A", {5}, kInt);1445BufHandle b("B", {5}, kInt);1446BufHandle c("C", {5}, kInt);1447VarHandle x("x", kInt);1448
1449StmtPtr stmt = Block::make({Cond::make(1450CompareSelect::make(Load::make(a, {x}), 5, CompareSelectOperation::kLT),1451Store::make(a, {x}, Add::make(Load::make(a, {x}), 1)),1452Store::make(a, {x}, Add::make(Load::make(a, {x}), 10)))});1453
1454/*1455* if ((A[x])<5 ? 1 : 0) {
1456* A[x] = (A[x]) + 1;
1457* } else {
1458* A[x] = (A[x]) + 10;
1459* }
1460*/
1461
1462stmt = registerize(stmt);1463
1464/*1465* int A_1 = A[x];
1466* if (A_1<5 ? 1 : 0) {
1467* A_1 = A_1 + 1;
1468* } else {
1469* A_1 = A_1 + 10;
1470* }
1471* A[x] = A_1;
1472*/
1473
1474std::ostringstream oss;1475oss << *stmt;1476
1477const std::string& verification_pattern =1478R"IR(1479# CHECK: int A_1 = A[x];
1480# CHECK: if (A_1<5
1481# CHECK: A_1 = A_1 + 1;
1482# CHECK: } else {
1483# CHECK: A_1 = A_1 + 10;
1484# CHECK: }
1485# CHECK: A[x] = A_1;)IR";1486
1487torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1488}
1489
1490// Conditional hiding also works for IfThenElse exprs.
1491TEST(Registerizer, RegisterizerIfThenElseHidden) {1492BufHandle a("A", {5}, kInt);1493BufHandle b("B", {5}, kInt);1494BufHandle c("C", {5}, kInt);1495VarHandle x("x", kInt);1496VarHandle y("y", kInt);1497
1498StmtPtr stmt = Block::make(1499{Store::make(1500b,1501{y},1502IfThenElse::make(1503CompareSelect::make(x, 5, CompareSelectOperation::kLT),1504Add::make(Load::make(a, {x}), 1),1505Add::make(Load::make(a, {x + 1}), 2))),1506Store::make(1507b,1508{y + 1},1509IfThenElse::make(1510CompareSelect::make(x, 5, CompareSelectOperation::kLT),1511Add::make(Load::make(a, {x}), 1),1512Add::make(Load::make(a, {x + 1}), 2)))});1513
1514/*1515* B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1516* B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1517*/
1518
1519std::ostringstream before;1520before << *stmt;1521
1522// No change.1523stmt = registerize(stmt);1524
1525std::ostringstream after;1526after << *stmt;1527
1528ASSERT_EQ(before.str(), after.str());1529}
1530
1531// Conditional unhiding also works for IfThenElse exprs.
1532TEST(Registerizer, RegisterizerIfThenElseUnhidden) {1533BufHandle a("A", {5}, kInt);1534BufHandle b("B", {5}, kInt);1535BufHandle c("C", {5}, kInt);1536VarHandle x("x", kInt);1537VarHandle y("y", kInt);1538
1539StmtPtr stmt = Block::make({1540Store::make(a, {x}, 0),1541Store::make(1542b,1543{y},1544IfThenElse::make(1545CompareSelect::make(x, 5, CompareSelectOperation::kLT),1546Add::make(Load::make(a, {x}), 1),1547Add::make(Load::make(a, {x + 1}), 2))),1548Store::make(1549b,1550{y + 1},1551IfThenElse::make(1552CompareSelect::make(x, 5, CompareSelectOperation::kLT),1553Add::make(Load::make(a, {x}), 1),1554Add::make(Load::make(a, {x + 1}), 2))),1555});1556
1557/*1558* A[x] = 0;
1559* B[y] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1560* B[y + 1] = IfThenElse(x<5 ? 1 : 0, (A[x]) + 1, (A[x + 1]) + 2);
1561*/
1562
1563stmt = registerize(stmt);1564
1565/*1566* int A_1 = 0;
1567* B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1568* B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1569* A[x] = A_1;
1570*/
1571
1572std::ostringstream oss;1573oss << *stmt;1574
1575const std::string& verification_pattern =1576R"IR(1577# CHECK: int A_1 = 0;
1578# CHECK: B[y] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1579# CHECK: B[y + 1] = IfThenElse(x<5 ? 1 : 0, A_1 + 1, (A[x + 1]) + 2);
1580# CHECK: A[x] = A_1;)IR";1581
1582torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1583}
1584
1585// Nested IfThenElse exprs can't promote to higher level scopes.
1586TEST(Registerizer, RegisterizerIfThenElseNested) {1587BufHandle a("A", {5}, kInt);1588BufHandle b("B", {5}, kInt);1589BufHandle c("C", {5}, kInt);1590BufHandle d("D", {5}, kInt);1591VarHandle x("x", kInt);1592
1593StmtPtr stmt = Block::make({Store::make(1594a,1595{x},1596IfThenElse::make(1597CompareSelect::make(x, 3, CompareSelectOperation::kLT),1598IfThenElse::make(1599CompareSelect::make(x, 2, CompareSelectOperation::kEQ),1600Load::make(d, {x}),1601Load::make(b, {x})),1602IfThenElse::make(1603CompareSelect::make(x, 5, CompareSelectOperation::kEQ),1604Load::make(c, {x}),1605Load::make(d, {x}))))});1606
1607/*1608* A[x] = IfThenElse(x<3 ? 1 : 0,
1609* IfThenElse(x==2 ? 1 : 0, D[x], B[x]),
1610* IfThenElse(x==5 ? 1 : 0, C[x], D[x]));
1611*/
1612
1613std::ostringstream before;1614before << *stmt;1615
1616// No change.1617stmt = registerize(stmt);1618
1619std::ostringstream after;1620after << *stmt;1621
1622ASSERT_EQ(before.str(), after.str());1623}
1624
1625// Cannot registerize an access completely contained within an IfThenElse
1626// branch, since it is not a Stmt and cannot hold variable definitions. We need
1627// to check that we don't promote the initializer/finalizer to the enclosing
1628// Block.
1629TEST(Registerizer, RegisterizerIfThenElseInternal) {1630// Making these floats so they don't get simplified to a single access.1631BufHandle a("A", {5}, kFloat);1632BufHandle b("B", {5}, kFloat);1633VarHandle x("x", kInt);1634
1635StmtPtr stmt = Block::make({Store::make(1636a,1637{x},1638IfThenElse::make(1639CompareSelect::make(x, 3, CompareSelectOperation::kLT),1640Add::make(Load::make(b, {x}), Load::make(b, {x})),1641Load::make(b, {x})))});1642
1643/*1644* A[x] = IfThenElse(x<3 ? 1 : 0, (B[x]) + (B[x]), B[x]);
1645*/
1646
1647std::ostringstream before;1648before << *stmt;1649
1650// No change.1651stmt = registerize(stmt);1652
1653std::ostringstream after;1654after << *stmt;1655
1656ASSERT_EQ(before.str(), after.str());1657
1658// If this was a Cond instead of an IfThenElse then we could registerize the1659// two accesses to B[x] in the True branch.1660
1661// Actually lets verify that.1662
1663stmt = Block::make({Cond::make(1664CompareSelect::make(x, 3, CompareSelectOperation::kLT),1665Store::make(a, {x}, Add::make(Load::make(b, {x}), Load::make(b, {x}))),1666Store::make(a, {x}, Load::make(b, {x})))});1667
1668/*1669* if (x<3 ? 1 : 0) {
1670* A[x] = (B[x]) + (B[x]);
1671* } else {
1672* A[x] = B[x];
1673* }
1674*/
1675
1676stmt = registerize(stmt);1677
1678/*1679* if (x<3 ? 1 : 0) {
1680* float B_1 = B[x];
1681* A[x] = B_1 + B_1;
1682* } else {
1683* A[x] = B[x];
1684* }
1685*/
1686
1687std::ostringstream oss;1688oss << *stmt;1689
1690const std::string& verification_pattern =1691R"IR(1692# CHECK-NOT: int
1693# CHECK-NOT: float
1694# CHECK: if (x<3
1695# CHECK: float B_1 =
1696# CHECK: A[x] = B_1 + B_1
1697# CHECK: } else {
1698# CHECK: A[x] = B[x]
1699# CHECK: }
1700# CHECK-NOT: A[x]
1701# CHECK-NOT: B[x])IR";1702
1703torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1704}
1705
1706// Can registerize a load that occurs in the condition of an IfThenElse;
1707TEST(Registerizer, RegisterizerIfThenElseCondition) {1708BufHandle a("A", {5}, kInt);1709BufHandle b("B", {5}, kInt);1710BufHandle c("C", {5}, kInt);1711VarHandle x("x", kInt);1712
1713StmtPtr stmt = Block::make(1714{Store::make(a, {x}, Load::make(a, {x})),1715Store::make(1716a,1717{x},1718IfThenElse::make(1719CompareSelect::make(1720Load::make(a, {x}), 5, CompareSelectOperation::kLT),1721Load::make(b, {0}),1722Load::make(c, {0})))});1723
1724/*1725* A[x] = A[x]; <---- just here so there are enough accesses to combine.
1726* A[x] = IfThenElse((A[x])<5 ? 1 : 0, B[0], C[0]);
1727*/
1728
1729stmt = registerize(stmt);1730
1731/*1732* int A_1 = A[x];
1733* A_1 = A_1;
1734* A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1735* A[x] = A_1;
1736*/
1737
1738std::ostringstream oss;1739oss << *stmt;1740
1741const std::string& verification_pattern =1742R"IR(1743# CHECK: int A_1 = A[x];
1744# CHECK: A_1 = IfThenElse(A_1<5 ? 1 : 0, B[0], C[0]);
1745# CHECK: A[x] = A_1;)IR";1746
1747torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1748}
1749
1750// Appearing in the condition of a Cond makes it visible to the enclosing scope,
1751// and so we can registerize internal usages.
1752TEST(Registerizer, RegisterizerIfThenElseConditionUnhidden) {1753BufHandle a("A", {5}, kInt);1754BufHandle b("B", {5}, kInt);1755BufHandle c("C", {5}, kInt);1756VarHandle x("x", kInt);1757
1758StmtPtr stmt = Block::make({Store::make(1759b,1760{x},1761IfThenElse::make(1762CompareSelect::make(1763Load::make(a, {x}), 5, CompareSelectOperation::kLT),1764Add::make(Load::make(a, {x}), 1),1765Add::make(Load::make(a, {x}), 10)))});1766
1767/*1768* B[x] = IfThenElse((A[x])<5 ? 1 : 0, (A[x]) + 1, (A[x]) + 10);
1769*/
1770
1771stmt = registerize(stmt);1772
1773/*1774* int A_1 = A[x];
1775* B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);
1776*/
1777
1778std::ostringstream oss;1779oss << *stmt;1780
1781const std::string& verification_pattern =1782R"IR(1783# CHECK: int A_1 = A[x];
1784# CHECK: B[x] = IfThenElse(A_1<5 ? 1 : 0, A_1 + 1, A_1 + 10);)IR";1785
1786torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1787}
1788
1789// Cannot promote accesses internal to IfThenElse branches even if the enclosing
1790// scope if conditional.
1791TEST(Registerizer, RegisterizerConditionBranchOnly) {1792BufHandle a("A", {5}, kInt);1793VarHandle x("x", kInt);1794StmtPtr stmt = Block::make({For::make(1795x,17960,179710,1798Block::make({1799Cond::make(1800CompareSelect::make(x, 5, CompareSelectOperation::kLT),1801Store::make(1802a,1803{x},1804IfThenElse::make(1805CompareSelect::make(x, 5, CompareSelectOperation::kLT),1806Add::make(Load::make(a, {x}), x),1807Add::make(Load::make(a, {x - 5}), x))),1808Store::make(1809a,1810{x - 5},1811IfThenElse::make(1812CompareSelect::make(x, 5, CompareSelectOperation::kLT),1813Add::make(Load::make(a, {x}), x),1814Add::make(Load::make(a, {x - 5}), x)))),1815}))});1816stmt = IRSimplifier::simplify(stmt);1817
1818std::ostringstream before;1819before << *stmt;1820
1821/* for (int x = 0; x < 10; x++) {1822* if (x<5 ? 1 : 0) {
1823* A[x] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1824* } else {
1825* A[x - 5] = IfThenElse(x<5 ? 1 : 0, (A[x]) + x, (A[x - 5]) + x);
1826* }
1827* }
1828*/
1829
1830// No change.1831stmt = registerize(stmt);1832
1833std::ostringstream after;1834after << *stmt;1835
1836ASSERT_EQ(before.str(), after.str());1837}
1838
1839// We can registerize an IfThenElse that appears in the condition branch of a
1840// Cond. This is a weird but valid thing to do.
1841TEST(Registerizer, RegisterizerCondIfThenElse) {1842BufHandle a("A", {5}, kInt);1843BufHandle b("B", {5}, kInt);1844BufHandle c("C", {5}, kInt);1845VarHandle x("x", kInt);1846
1847StmtPtr stmt = Block::make({Cond::make(1848CompareSelect::make(1849IfThenElse::make(1850CompareSelect::make(1851Load::make(a, {x}), 5, CompareSelectOperation::kLT),1852Load::make(a, {x}),1853Load::make(b, {x})),1854x,1855CompareSelectOperation::kEQ),1856Store::make(c, {x}, Add::make(Load::make(c, {x}), 1)),1857nullptr)});1858
1859/*1860* if ((IfThenElse((A[x])<5 ? 1 : 0, A[x], B[x]))==x ? 1 : 0) {
1861* C[x] = (C[x]) + 1;
1862* }
1863*/
1864
1865stmt = registerize(stmt);1866
1867// access to A can be registerized, but not B or C1868
1869/*1870* int A_1 = A[x];
1871* if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]))==x ? 1 : 0) {
1872* C[x] = (C[x]) + 1;
1873* }
1874*/
1875
1876std::ostringstream oss;1877oss << *stmt;1878
1879const std::string& verification_pattern =1880R"IR(1881# CHECK: int A_1 = A[x];
1882# CHECK: if ((IfThenElse(A_1<5 ? 1 : 0, A_1, B[x]
1883# CHECK: C[x] = (C[x]) + 1;)IR";1884
1885torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1886}
1887
1888// Can registerize a conditional access in the RHS of a store unhidden by it's
1889// LHS, and hoist it out of a loop.
1890TEST(Registerizer, RegisterizerIfThenElseLoop) {1891BufHandle a("A", {5}, kInt);1892BufHandle b("B", {5}, kInt);1893VarHandle x("x", kInt);1894VarHandle y("y", kInt);1895
1896StmtPtr stmt = For::make(1897y,18980,189910,1900Store::make(1901a,1902{x},1903IfThenElse::make(1904CompareSelect::make(x, 3, CompareSelectOperation::kLT),1905Load::make(a, {x}),1906Load::make(b, {y}))));1907
1908/*1909* for (int y = 0; y < 10; y++) {
1910* A[x] = IfThenElse(x<3 ? 1 : 0, A[x], B[y]);
1911* }
1912*/
1913
1914stmt = registerize(stmt);1915
1916/*1917* int A_1 = A[x];
1918* for (int y = 0; y < 10; y++) {
1919* A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1920* }
1921* A[x] = A_1;
1922*/
1923
1924std::ostringstream oss;1925oss << *stmt;1926
1927const std::string& verification_pattern =1928R"IR(1929# CHECK: int A_1 = A[x];
1930# CHECK: for (
1931# CHECK: A_1 = IfThenElse(x<3 ? 1 : 0, A_1, B[y]);
1932# CHECK: }
1933# CHECK: A[x] = A_1;)IR";1934
1935torch::jit::testing::FileCheck().run(verification_pattern, oss.str());1936}
1937
1938// Cannot registerize if the RHS overlaps the access creating visibility.
1939TEST(Registerizer, RegisterizerIfThenElseLoopCut) {1940BufHandle a("A", {5}, kInt);1941BufHandle b("B", {5}, kInt);1942VarHandle x("x", kInt);1943VarHandle y("y", kInt);1944
1945StmtPtr stmt = Block::make({For::make(1946y,19470,194810,1949Store::make(1950a,1951{x},1952IfThenElse::make(1953CompareSelect::make(x, 3, CompareSelectOperation::kLT),1954Load::make(a, {x}),1955Load::make(a, {y}))))});1956
1957/*1958* for (int y = 0; y < 10; y++) {
1959* A[x] = IfThenElse(x<3 ? 1 : 0, A[x], A[y]);
1960* }
1961*/
1962
1963std::ostringstream before;1964before << *stmt;1965
1966// No change.1967stmt = registerize(stmt);1968
1969std::ostringstream after;1970after << *stmt;1971
1972ASSERT_EQ(before.str(), after.str());1973}
1974
1975// Simple case where an access is cut by an overlapping access later in the
1976// program, we can registerize up until the overlap.
1977TEST(Registerizer, RegisterizerPartialAfter) {1978BufHandle a("A", {1}, kInt);1979VarHandle x("x", kInt);1980StmtPtr stmt = Block::make(1981{Store::make(a, {0}, 0),1982For::make(1983x,19840,198510,1986Block::make(1987{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))})),1988For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1})))});1989
1990/*1991* A[0] = 0;
1992* for (int x = 0; x < 10; x++) {
1993* A[0] = (A[0]) + x;
1994* }
1995* for (int x = 1; x < 10; x++) {
1996* A[x] = A[x - 1];
1997* }
1998*/
1999
2000stmt = registerize(stmt);2001
2002/*2003* int A_1 = 0;
2004* for (int x = 0; x < 10; x++) {
2005* A_1 = A_1 + x;
2006* }
2007* A[0] = A_1;
2008* for (int x = 1; x < 10; x++) {
2009* A[x] = A[x - 1];
2010* }
2011*/
2012
2013std::ostringstream oss;2014oss << *stmt;2015
2016const std::string& verification_pattern =2017R"IR(2018# CHECK: int A_1 = 0;
2019# CHECK: for (
2020# CHECK: A_1 = A_1 + x;
2021# CHECK: }
2022# CHECK: A[0] = A_1;
2023# CHECK: for (
2024# CHECK: A[x] = A[x - 1];
2025# CHECK: }
2026# CHECK-NOT: A)IR";2027
2028torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2029}
2030
2031// We can registerize an access which overlaps a previous access, the
2032// initializer must be inserted after the previous access.
2033TEST(Registerizer, RegisterizerPartialBefore) {2034BufHandle a("A", {1}, kInt);2035VarHandle x("x", kInt);2036StmtPtr stmt = Block::make(2037{For::make(x, 1, 10, Store::make(a, {x}, Load::make(a, {x - 1}))),2038Store::make(a, {0}, 0),2039For::make(2040x,20410,204210,2043Block::make(2044{Store::make(a, {0}, Add::make(Load::make(a, {0}), x))}))});2045
2046/*2047* for (int x = 1; x < 10; x++) {
2048* A[x] = A[x - 1];
2049* }
2050* A[0] = 0;
2051* for (int x = 0; x < 10; x++) {
2052* A[0] = (A[0]) + x;
2053* }
2054*/
2055
2056stmt = registerize(stmt);2057
2058/*2059* for (int x = 1; x < 10; x++) {
2060* A[x] = A[x - 1];
2061* }
2062* int A_1 = 0;
2063* for (int x = 0; x < 10; x++) {
2064* A_1 = A_1 + x;
2065* }
2066* A[0] = A_1;
2067*/
2068
2069std::ostringstream oss;2070oss << *stmt;2071
2072const std::string& verification_pattern =2073R"IR(2074# CHECK-NOT: int
2075# CHECK: for (
2076# CHECK: A[x] = A[x - 1];
2077# CHECK: }
2078# CHECK: int A_1 = 0;
2079# CHECK: for (
2080# CHECK: A_1 = A_1 + x;
2081# CHECK: }
2082# CHECK: A[0] = A_1;)IR";2083
2084torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2085}
2086
2087// The combination of the previous two tests, an access is cut by an overlapping
2088// access in both directions.
2089TEST(Registerizer, RegisterizerPartialInside) {2090BufHandle a("A", {1}, kInt);2091VarHandle x1("x1", kInt);2092VarHandle x2("x2", kInt);2093VarHandle x3("x3", kInt);2094StmtPtr stmt = Block::make(2095{Store::make(a, {0}, 2),2096For::make(2097x1, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x1))),2098For::make(x2, 1, 10, Store::make(a, {x2}, Load::make(a, {x2 - 1}))),2099For::make(2100x3, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x3)))});2101
2102/*2103* A[0] = 2;
2104* for (int x1 = 0; x1 < 10; x1++) {
2105* A[0] = (A[0]) + x1;
2106* }
2107* for (int x2 = 1; x2 < 10; x2++) {
2108* A[x2] = A[x2 - 1];
2109* }
2110* for (int x3 = 0; x3 < 10; x3++) {
2111* A[0] = (A[0]) + x3;
2112* }
2113*/
2114
2115stmt = registerize(stmt);2116
2117/*2118* int A_1 = 2;
2119* for (int x1 = 0; x1 < 10; x1++) {
2120* A_1 = A_1 + x1;
2121* }
2122* A[0] = A_1;
2123* for (int x2 = 1; x2 < 10; x2++) {
2124* A[x2] = A[x2 - 1];
2125* }
2126* int A_2 = A[0];
2127* for (int x3 = 0; x3 < 10; x3++) {
2128* A_2 = A_2 + x3;
2129* }
2130* A[0] = A_2;
2131*/
2132
2133std::ostringstream oss;2134oss << *stmt;2135
2136const std::string& verification_pattern =2137R"IR(2138# CHECK: int A_1 = 2;
2139# CHECK: for (
2140# CHECK: A_1 = A_1 + x1;
2141# CHECK: }
2142# CHECK: A[0] = A_1;
2143# CHECK: for (
2144# CHECK: A[x2] =
2145# CHECK: }
2146# CHECK: int A_2 = A[0];
2147# CHECK: for (
2148# CHECK: A_2 = A_2 + x3;
2149# CHECK: }
2150# CHECK: A[0] = A_2;)IR";2151
2152torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2153}
2154
2155// An element could be registerized program wide but is cut by a conditional
2156// access, we should break this into two scalars and write back to the buffer
2157// before the condition.
2158TEST(Registerizer, RegisterizerPartialCondition) {2159BufHandle a("A", {1}, kInt);2160VarHandle x("x", kInt);2161StmtPtr stmt = Block::make(2162{Store::make(a, {0}, 2),2163For::make(2164x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x))),2165Cond::make(2166CompareSelect::make(x, 5, CompareSelectOperation::kLT),2167Store::make(a, {x}, Load::make(a, {x - 1})),2168nullptr),2169For::make(2170x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), x)))});2171
2172/*2173* A[0] = 2;
2174* for (int x = 0; x < 10; x++) {
2175* A[0] = (A[0]) + x;
2176* }
2177* if (x<5 ? 1 : 0) {
2178* A[x] = A[x - 1];
2179* }
2180* for (int x = 0; x < 10; x++) {
2181* A[0] = (A[0]) + x;
2182* }
2183*/
2184
2185stmt = registerize(stmt);2186
2187/*2188* int A_1 = 2;
2189* for (int x = 0; x < 10; x++) {
2190* A_1 = A_1 + x;
2191* }
2192* A[0] = A_1;
2193* if (x<5 ? 1 : 0) {
2194* A[x] = A[x - 1];
2195* }
2196* int A_2 = A[0];
2197* for (int x = 0; x < 10; x++) {
2198* A_2 = A_2 + x;
2199* }
2200* A[0] = A_2;
2201*/
2202
2203std::ostringstream oss;2204oss << *stmt;2205
2206const std::string& verification_pattern =2207R"IR(2208# CHECK: int A_1 = 2;
2209# CHECK: for (
2210# CHECK: A_1 = A_1 + x;
2211# CHECK: }
2212# CHECK: A[0] = A_1;
2213# CHECK: if (
2214# CHECK: A[x] =
2215# CHECK: }
2216# CHECK: int A_2 = A[0];
2217# CHECK: for (
2218# CHECK: A_2 = A_2 + x;
2219# CHECK: }
2220# CHECK: A[0] = A_2;)IR";2221
2222torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2223}
2224
2225// Tests case where an access is cut by an internal conditional access which
2226// itself is registerized.
2227TEST(Registerizer, RegisterizerPartialConditionInternalCut) {2228BufHandle a("A", {1}, kInt);2229VarHandle x("x", kInt);2230StmtPtr stmt = Block::make(2231{Store::make(a, {0}, 1),2232Store::make(a, {0}, 3),2233Cond::make(2234CompareSelect::make(x, 5, CompareSelectOperation::kLT),2235Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),2236nullptr),2237Store::make(a, {0}, 4),2238Store::make(a, {0}, 6)});2239
2240/*2241* A[0] = 1;
2242* A[0] = 3;
2243* if (x<5 ? 1 : 0) {
2244* A[x] = 1;
2245* A[x] = 3;
2246* }
2247* A[0] = 4;
2248* A[0] = 6;
2249*/
2250
2251stmt = registerize(stmt);2252
2253/*2254* int A_1 = 1;
2255* A_1 = 3;
2256* A[0] = A_1;
2257* if (x<5 ? 1 : 0) {
2258* int A_2 = 1;
2259* A_2 = 3;
2260* A[x] = A_2;
2261* }
2262* int A_3 = 4;
2263* A_3 = 6;
2264* A[0] = A_3;
2265*/
2266
2267std::ostringstream oss;2268oss << *stmt;2269
2270const std::string& verification_pattern =2271R"IR(2272# CHECK: int A_1 = 1;
2273# CHECK: A_1 = 3
2274# CHECK: A[0] = A_1;
2275# CHECK: if (
2276# CHECK: int A_2 = 1;
2277# CHECK: A_2 = 3;
2278# CHECK: A[x] = A_2;
2279# CHECK: }
2280# CHECK: int A_3 = 4;
2281# CHECK: A_3 = 6;
2282# CHECK: A[0] = A_3;)IR";2283
2284torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2285}
2286
2287// First statement in condition closes outer access, but can be registerized
2288// with later statements.
2289TEST(Registerizer, RegisterizerPartialConditionInternalStart) {2290BufHandle a("A", {1}, kInt);2291VarHandle x("x", kInt);2292StmtPtr stmt = Block::make(2293{Store::make(a, {0}, 1),2294Store::make(a, {0}, 3),2295Cond::make(2296CompareSelect::make(x, 5, CompareSelectOperation::kLT),2297Block::make({Store::make(a, {x}, 1), Store::make(a, {x}, 3)}),2298nullptr),2299Store::make(a, {x}, 4),2300Store::make(a, {x}, 6)});2301
2302/*2303* A[0] = 1;
2304* A[0] = 3;
2305* if (x<5 ? 1 : 0) {
2306* A[x] = 1;
2307* A[x] = 3;
2308* }
2309* A[x] = 4;
2310* A[x] = 6;
2311*/
2312
2313stmt = registerize(stmt);2314
2315/*2316* int A_1 = 1;
2317* A_1 = 3;
2318* A[0] = A_1;
2319* int A_2 = A[x]; <--- must read from the input here.
2320* if (x<5 ? 1 : 0) {
2321* A_2 = 1;
2322* A_2 = 3;
2323* }
2324* A_2 = 4;
2325* A_2 = 6;
2326* A[x] = A_2;
2327*/
2328
2329// TODO: I suppose we could refactor with a conditional initializer?2330
2331std::ostringstream oss;2332oss << *stmt;2333
2334const std::string& verification_pattern =2335R"IR(2336# CHECK: int A_1 = 1;
2337# CHECK: A_1 = 3
2338# CHECK: A[0] = A_1;
2339# CHECK: int A_2 = A[x];
2340# CHECK: if (
2341# CHECK: A_2 = 1;
2342# CHECK: A_2 = 3;
2343# CHECK: }
2344# CHECK: A_2 = 4;
2345# CHECK: A_2 = 6;
2346# CHECK: A[x] = A_2;)IR";2347
2348torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2349}
2350
2351// An access cuts two open overlaps and creates four scalar variables.
2352TEST(Registerizer, RegisterizerPartialOverlapsTwo) {2353BufHandle a("A", {1}, kInt);2354VarHandle x("x", kInt);2355StmtPtr stmt = Block::make(2356{Store::make(a, {1}, Load::make(a, {0})),2357Store::make(a, {0}, Load::make(a, {1})),2358Store::make(a, {0}, Load::make(a, {1})),2359For::make(x, 1, 10, Store::make(a, {x}, x)),2360Store::make(a, {1}, Load::make(a, {0})),2361Store::make(a, {0}, Load::make(a, {1})),2362Store::make(a, {0}, Load::make(a, {1}))});2363
2364/*2365* A[1] = A[0];
2366* A[0] = A[1];
2367* A[0] = A[1];
2368* for (int x = 1; x < 10; x++) {
2369* A[x] = x;
2370* }
2371* A[1] = A[0];
2372* A[0] = A[1];
2373* A[0] = A[1];
2374*/
2375
2376stmt = registerize(stmt);2377
2378/*2379* int A_1 = A[0];
2380* int A_2 = A_1;
2381* A_1 = A_2;
2382* A_1 = A_2;
2383* A[1] = A_2;
2384* A[0] = A_1;
2385* for (int x = 1; x < 10; x++) {
2386* A[x] = x;
2387* }
2388* int A_3 = A[0];
2389* int A_4 = A_3;
2390* A_3 = A_4;
2391* A_3 = A_4;
2392* A[1] = A_4;
2393* A[0] = A_3;
2394*/
2395
2396std::ostringstream oss;2397oss << *stmt;2398
2399const std::string& verification_pattern =2400R"IR(2401# CHECK: int A_1 = A[0];
2402# CHECK: int A_2 = A_1;
2403# CHECK: A_1 = A_2;
2404# CHECK: A_1 = A_2;
2405# CHECK: A[1] = A_2;
2406# CHECK: A[0] = A_1;
2407# CHECK: for (
2408# CHECK: A[x] = x;
2409# CHECK: }
2410# CHECK: int A_3 = A[0];
2411# CHECK: int A_4 = A_3;
2412# CHECK: A_3 = A_4;
2413# CHECK: A_3 = A_4;
2414# CHECK: A[1] = A_4;
2415# CHECK: A[0] = A_3;)IR";2416
2417torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2418}
2419
2420// Nested blocks will automatically be flattened and do not provent
2421// registerization of enclosed accesses.
2422TEST(Registerizer, RegisterizerNestedBlocks) {2423BufHandle a("A", {1}, kInt);2424VarHandle x("x", kInt);2425StmtPtr stmt = Block::make(2426// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)2427{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2428Block::make({Store::make(a, {0}, Add::make(Load::make(a, {0}), 2))}),2429Block::make(2430{Store::make(a, {0}, Add::make(Load::make(a, {0}), 3)),2431Block::make(2432{Store::make(a, {0}, Add::make(Load::make(a, {0}), 4))})})});2433
2434/*2435* A[0] = (A[0]) + 1;
2436* {
2437* A[0] = (A[0]) + 2;
2438* }
2439* {
2440* A[0] = (A[0]) + 3;
2441* {
2442* A[0] = (A[0]) + 4;
2443* }
2444* }
2445*/
2446
2447stmt = registerize(stmt);2448
2449/*2450* int A_1 = A[0];
2451* A_1 = A_1 + 1;
2452* A_1 = A_1 + 2;
2453* A_1 = A_1 + 3;
2454* A_1 = A_1 + 4;
2455* A[0] = A_1;
2456*/
2457
2458std::ostringstream oss;2459oss << *stmt;2460
2461const std::string& verification_pattern =2462R"IR(2463# CHECK: int A_1 = A[0];
2464# CHECK: A_1 = A_1 + 1;
2465# CHECK: A_1 = A_1 + 2;
2466# CHECK: A_1 = A_1 + 3;
2467# CHECK: A_1 = A_1 + 4;
2468# CHECK: A[0] = A_1;)IR";2469
2470torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2471}
2472
2473// The access can be registerized internally to a condition, but must ensure
2474// that both initializer and finalizer are within the same condition.
2475TEST(Registerizer, RegisterizerNestedConditions) {2476BufHandle a("A", {1}, kInt);2477VarHandle x("x", kInt);2478StmtPtr stmt = Block::make({Cond::make(2479CompareSelect::make(x, 5, CompareSelectOperation::kLT),2480Block::make(2481{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2482Cond::make(2483CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2484Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2485nullptr)}),2486nullptr)});2487
2488/*2489* if (x<5 ? 1 : 0) {
2490* A[0] = (A[0]) + 1;
2491* if (x==2 ? 1 : 0) {
2492*
2493* A[0] = (A[0]) + 1;
2494* }
2495* }
2496*/
2497
2498stmt = registerize(stmt);2499
2500/*2501* if (x<5 ? 1 : 0) {
2502* int A_1 = A[0];
2503* A_1 = A_1 + 1;
2504* if (x==2 ? 1 : 0) {
2505* A_1 = A_1 + 1;
2506* }
2507* A[0] = A_1;
2508* }
2509*/
2510
2511std::ostringstream oss;2512oss << *stmt;2513
2514const std::string& verification_pattern =2515R"IR(2516# CHECK: if (x<5
2517# CHECK: int A_1 = A[0];
2518# CHECK: A_1 = A_1 + 1;
2519# CHECK: if (x==2
2520# CHECK: A_1 = A_1 + 1;
2521# CHECK: }
2522# CHECK: A[0] = A_1;
2523# CHECK: })IR";2524
2525torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2526}
2527
2528// If an access exists outside the scope of the condition then we can lift
2529// nested conditional usages into the same scalar.
2530TEST(Registerizer, RegisterizerNestedConditionsUnhidden) {2531BufHandle a("A", {1}, kInt);2532VarHandle x("x", kInt);2533StmtPtr stmt = Block::make(2534{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2535Cond::make(2536CompareSelect::make(x, 5, CompareSelectOperation::kLT),2537Block::make(2538{Store::make(a, {1}, 1),2539Cond::make(2540CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2541Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2542nullptr)}),2543nullptr)});2544
2545/*2546* A[0] = (A[0]) + 1;
2547* if (x<5 ? 1 : 0) {
2548* A[1] = 1;
2549* if (x==2 ? 1 : 0) {
2550* A[0] = (A[0]) + 1;
2551* }
2552* }
2553*/
2554
2555stmt = registerize(stmt);2556
2557/*2558* int A_1 = A[0];
2559* A_1 = A_1 + 1;
2560* if (x<5 ? 1 : 0) {
2561* A[1] = 1;
2562* if (x==2 ? 1 : 0) {
2563* A_1 = A_1 + 1;
2564* }
2565* }
2566* A[0] = A_1;
2567*/
2568
2569std::ostringstream oss;2570oss << *stmt;2571
2572const std::string& verification_pattern =2573R"IR(2574# CHECK: int A_1 = A[0];
2575# CHECK: A_1 = A_1 + 1;
2576# CHECK: if (x<5
2577# CHECK: A[1] = 1;
2578# CHECK: if (x==2
2579# CHECK: A_1 = A_1 + 1;
2580# CHECK: A[0] = A_1;)IR";2581
2582torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2583}
2584
2585TEST(Registerizer, RegisterizerNestedConditionsHiddenFirst) {2586BufHandle a("A", {1}, kInt);2587VarHandle x("x", kInt);2588StmtPtr stmt = Block::make(2589{Cond::make(2590CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2591Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2592nullptr),2593Cond::make(2594CompareSelect::make(x, 5, CompareSelectOperation::kLT),2595Block::make({Cond::make(2596CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2597Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2598nullptr)}),2599nullptr)});2600
2601/*2602* if (x==2 ? 1 : 0) {
2603* A[0] = (A[0]) + 1;
2604* }
2605* if (x<5 ? 1 : 0) {
2606* if (x==2 ? 1 : 0) {
2607* A[0] = (A[0]) + 1;
2608* }
2609* }
2610*/
2611
2612std::ostringstream before;2613before << *stmt;2614
2615// No change.2616stmt = registerize(stmt);2617
2618std::ostringstream after;2619after << *stmt;2620
2621ASSERT_EQ(before.str(), after.str());2622
2623// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)2624stmt = registerize(stmt);2625}
2626
2627TEST(Registerizer, RegisterizerNestedConditionsHiddenSecond) {2628BufHandle a("A", {1}, kInt);2629VarHandle x("x", kInt);2630StmtPtr stmt = Block::make(2631{Cond::make(2632CompareSelect::make(x, 5, CompareSelectOperation::kLT),2633Block::make({Cond::make(2634CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2635Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2636nullptr)}),2637nullptr),2638Cond::make(2639CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2640Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2641nullptr)});2642
2643/*2644* if (x<5 ? 1 : 0) {
2645* if (x==2 ? 1 : 0) {
2646* A[0] = (A[0]) + 1;
2647* }
2648* }
2649* if (x==2 ? 1 : 0) {
2650* A[0] = (A[0]) + 1;
2651* }
2652*/
2653
2654std::ostringstream before;2655before << *stmt;2656
2657// No change.2658stmt = registerize(stmt);2659
2660std::ostringstream after;2661after << *stmt;2662
2663ASSERT_EQ(before.str(), after.str());2664
2665// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)2666stmt = registerize(stmt);2667}
2668
2669// If an access is cut by another access internal to a condition block, it still
2670// cuts the access.
2671TEST(Registerizer, RegisterizerNestedConditionsCut) {2672BufHandle a("A", {1}, kInt);2673VarHandle x("x", kInt);2674StmtPtr stmt = Block::make(2675{Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2676Cond::make(2677CompareSelect::make(x, 5, CompareSelectOperation::kLT),2678Block::make(2679{Store::make(a, {x}, 1),2680Cond::make(2681CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2682Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2683nullptr)}),2684nullptr)});2685
2686/*2687* A[0] = (A[0]) + 1;
2688* if (x<5 ? 1 : 0) {
2689* A[x] = 1;
2690* if (x==2 ? 1 : 0) {
2691*
2692* A[0] = (A[0]) + 1;
2693* }
2694* }
2695*/
2696
2697std::ostringstream before;2698before << *stmt;2699
2700// No change.2701stmt = registerize(stmt);2702
2703std::ostringstream after;2704after << *stmt;2705
2706ASSERT_EQ(before.str(), after.str());2707}
2708
2709TEST(Registerizer, RegisterizerNestedConditionLoopHidden) {2710BufHandle a("A", {10}, kInt);2711BufHandle b("B", {10}, kInt);2712VarHandle x("x", kInt);2713StmtPtr stmt = Block::make(2714{Cond::make(2715CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2716Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2717nullptr),2718For::make(2719x,27200,272110,2722Block::make(2723{Store::make(b, {x}, 0),2724Cond::make(2725CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2726Store::make(a, {0}, Add::make(Load::make(a, {0}), 1)),2727nullptr)}))});2728
2729/*2730* if (x==2 ? 1 : 0) {
2731* A[0] = (A[0]) + 1;
2732* }
2733* for (int x = 0; x < 10; x++) {
2734* B[x] = 0; <-- this is only here to prevent Loop/Cond reordering.
2735* if (x==2 ? 1 : 0) {
2736* A[0] = (A[0]) + 1;
2737* }
2738* }
2739*/
2740
2741std::ostringstream before;2742before << *stmt;2743
2744// No change.2745stmt = registerize(stmt);2746
2747std::ostringstream after;2748after << *stmt;2749
2750ASSERT_EQ(before.str(), after.str());2751}
2752
2753// Three loops and four element regions, three of which should be registerized
2754// at different levels of the IR.
2755TEST(Registerizer, RegisterizerNestedConditionThreeDeep) {2756BufHandle a("A", {10}, kInt);2757BufHandle b("B", {10}, kInt);2758VarHandle x("x", kInt);2759StmtPtr stmt = Block::make(2760{Store::make(a, {4}, 0),2761Cond::make(2762CompareSelect::make(x, 2, CompareSelectOperation::kGT),2763Cond::make(2764CompareSelect::make(x, 3, CompareSelectOperation::kGT),2765Block::make({2766Cond::make(2767CompareSelect::make(x, 4, CompareSelectOperation::kGT),2768Block::make({2769Store::make(2770a, {1}, Add::make(Load::make(a, {1}), 1)),2771Store::make(2772a, {2}, Add::make(Load::make(a, {2}), 1)),2773Store::make(2774a, {3}, Add::make(Load::make(a, {3}), 1)),2775Store::make(2776a, {4}, Add::make(Load::make(a, {4}), 1)),2777Store::make(2778a, {1}, Add::make(Load::make(a, {1}), 1)),2779}),2780nullptr),2781Store::make(a, {2}, Add::make(Load::make(a, {2}), 1)),2782}),2783nullptr),2784nullptr)});2785
2786/*2787* A[4] = 0;
2788* if (x>2 ? 1 : 0) {
2789* if (x>3 ? 1 : 0) {
2790* if (x>4 ? 1 : 0) {
2791* A[1] = (A[1]) + 1;
2792* A[2] = (A[2]) + 1;
2793* A[3] = (A[3]) + 1;
2794* A[4] = (A[4]) + 1;
2795* A[1] = (A[1]) + 1;
2796* }
2797* A[2] = (A[2]) + 1;
2798* }
2799* }
2800*/
2801
2802stmt = registerize(stmt);2803
2804/*2805* int A_1 = 0;
2806* if (x>2 ? 1 : 0) {
2807* if (x>3 ? 1 : 0) {
2808* int A_3 = A[2];
2809* if (x>4 ? 1 : 0) {
2810* int A_2 = A[1];
2811* A_2 = A_2 + 1;
2812* A_3 = A_3 + 1;
2813* A[3] = (A[3]) + 1;
2814* A_1 = A_1 + 1;
2815* A_2 = A_2 + 1;
2816* A[1] = A_2;
2817* }
2818* A_3 = A_3 + 1;
2819* A[2] = A_3;
2820* }
2821* }
2822* A[4] = A_1;
2823*/
2824
2825std::ostringstream oss;2826oss << *stmt;2827
2828const std::string& verification_pattern =2829R"IR(2830# CHECK: int A_1 = 0;
2831# CHECK: if (x>2 ? 1 : 0) {
2832# CHECK: if (x>3 ? 1 : 0) {
2833# CHECK: int A_3 = A[2];
2834# CHECK: if (x>4 ? 1 : 0) {
2835# CHECK: int A_2 = A[1];
2836# CHECK: A_2 = A_2 + 1;
2837# CHECK: A_3 = A_3 + 1;
2838# CHECK: A[3] = (A[3]) + 1;
2839# CHECK: A_1 = A_1 + 1;
2840# CHECK: A_2 = A_2 + 1;
2841# CHECK: A[1] = A_2;
2842# CHECK: }
2843# CHECK: A_3 = A_3 + 1;
2844# CHECK: A[2] = A_3;
2845# CHECK: }
2846# CHECK: }
2847# CHECK: A[4] = A_1;)IR";2848
2849torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2850}
2851
2852// Can replace a simple scalar access with a local variable even when that
2853// variable is an outer loop var.
2854TEST(Registerizer, RegisterizerNestedLoopSimple) {2855BufHandle a("A", {1}, kInt);2856VarHandle x("x", kInt);2857VarHandle y("y", kInt);2858StmtPtr stmt = Block::make({For::make(2859y,28600,286110,2862For::make(2863x,28640,286510,2866Block::make(2867{Store::make(a, {y}, Add::make(Load::make(a, {y}), x))})))});2868
2869/*2870* for (int y = 0; y < 10; y++) {
2871* for (int x = 0; x < 10; x++) {
2872* A[y] = (A[y]) + x;
2873* }
2874* }
2875*/
2876
2877stmt = registerize(stmt);2878
2879/*2880* for (int y = 0; y < 10; y++) {
2881* int A_1 = A[y];
2882* for (int x = 0; x < 10; x++) {
2883* A_1 = A_1 + x;
2884* }
2885* A[y] = A_1;
2886* }
2887*/
2888
2889std::ostringstream oss;2890oss << *stmt;2891
2892const std::string& verification_pattern =2893R"IR(2894# CHECK: for (int y
2895# CHECK: int A_1 = A[y];
2896# CHECK: for (int x
2897# CHECK: A_1 = A_1 + x;
2898# CHECK: }
2899# CHECK: A[y] = A_1;
2900# CHECK: })IR";2901
2902torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2903}
2904
2905// Test the positive case of the hiddenAccess split, where an internal
2906// conditional access can be hoisted up through a loop to match an existing
2907// access in a higher scope and the two can be registerized.
2908TEST(Registerizer, RegisterizerHiddenAccessYes) {2909BufHandle a("A", {10}, kInt);2910BufHandle b("B", {10}, kInt);2911VarHandle x("x", kInt);2912VarHandle y("y", kInt);2913StmtPtr stmt = Block::make({Cond::make(2914CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2915Block::make(2916{Store::make(a, {0}, 0),2917For::make(2918x,29190,292010,2921Block::make(2922{Store::make(b, {x}, 0),2923// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)2924Cond::make(2925CompareSelect::make(x, 3, CompareSelectOperation::kEQ),2926For::make(2927y,29280,292910,2930Store::make(2931a, {0}, Add::make(Load::make(a, {0}), 1))),2932nullptr)}))}),2933nullptr)});2934
2935/*2936* if (x==2 ? 1 : 0) {
2937* A[0] = 0;
2938* for (int x = 0; x < 10; x++) {
2939* B[x] = 0;
2940* if (x==3 ? 1 : 0) {
2941* for (int y = 0; y < 10; y++) {
2942* A[0] = (A[0]) + 1;
2943* }
2944* }
2945* }
2946* }
2947*/
2948
2949stmt = registerize(stmt);2950
2951/*2952* if (x==2 ? 1 : 0) {
2953* int A_1 = 0;
2954* for (int x = 0; x < 10; x++) {
2955* B[x] = 0;
2956* if (x==3 ? 1 : 0) {
2957* for (int y = 0; y < 10; y++) {
2958* A_1 = A_1 + 1;
2959* }
2960* }
2961* }
2962* A[0] = A_1;
2963* }
2964*/
2965
2966std::ostringstream oss;2967oss << *stmt;2968
2969const std::string& verification_pattern =2970R"IR(2971# CHECK: if (x==2
2972# CHECK: int A_1 = 0;
2973# CHECK: for (int x
2974# CHECK: B[x] = 0;
2975# CHECK: if (x==3
2976# CHECK: for (int y
2977# CHECK: A_1 = A_1 + 1;
2978# CHECK: }
2979# CHECK: }
2980# CHECK: }
2981# CHECK: A[0] = A_1;
2982# CHECK: })IR";2983
2984torch::jit::testing::FileCheck().run(verification_pattern, oss.str());2985}
2986
2987// Test the negative case of the hiddenAccess split, where the hoisted access is
2988// never unhidden at a higher scope and registerization occurs at the lower
2989// scope.
2990TEST(Registerizer, RegisterizerHiddenAccessNo) {2991BufHandle a("A", {10}, kInt);2992BufHandle b("B", {10}, kInt);2993VarHandle x("x", kInt);2994VarHandle y("y", kInt);2995StmtPtr stmt = Block::make({Cond::make(2996CompareSelect::make(x, 2, CompareSelectOperation::kEQ),2997Block::make({For::make(2998x,29990,300010,3001Block::make(3002{Store::make(b, {x}, 0),3003// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)3004Cond::make(3005CompareSelect::make(x, 3, CompareSelectOperation::kEQ),3006For::make(3007y,30080,300910,3010Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),3011nullptr)}))}),3012nullptr)});3013
3014/*3015* if (x==2 ? 1 : 0) {
3016* A[0] = 0;
3017* for (int x = 0; x < 10; x++) {
3018* B[x] = 0;
3019* if (x==3 ? 1 : 0) {
3020* for (int y = 0; y < 10; y++) {
3021* A[0] = (A[0]) + 1;
3022* }
3023* }
3024* }
3025* }
3026*/
3027
3028stmt = registerize(stmt);3029
3030/*3031* if (x==2 ? 1 : 0) {
3032* for (int x = 0; x < 10; x++) {
3033* B[x] = 0;
3034* if (x==3 ? 1 : 0) {
3035* int A_1 = A[0];
3036* for (int y = 0; y < 10; y++) {
3037* A_1 = A_1 + 1;
3038* }
3039* A[0] = A_1;
3040* }
3041* }
3042* }
3043*/
3044
3045std::ostringstream oss;3046oss << *stmt;3047
3048const std::string& verification_pattern =3049R"IR(3050# CHECK: if (x==2
3051# CHECK: for (int x
3052# CHECK: B[x] = 0;
3053# CHECK: if (x==3
3054# CHECK: int A_1 = A[0];
3055# CHECK: for (int y
3056# CHECK: A_1 = A_1 + 1;
3057# CHECK: }
3058# CHECK: A[0] = A_1;
3059# CHECK: }
3060# CHECK: }
3061# CHECK: })IR";3062
3063torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3064}
3065
3066// In this case the conditional access must be hoisted by two loops, there are
3067// two accesses here one is unhidden and the other isnt. A[0] can be
3068// registerized but B[0] cannot.
3069TEST(Registerizer, RegisterizerHiddenAccessMultiLoop) {3070BufHandle a("A", {10}, kInt);3071BufHandle b("B", {10}, kInt);3072VarHandle x("x", kInt);3073VarHandle y("y", kInt);3074StmtPtr stmt = Block::make({Cond::make(3075CompareSelect::make(x, 2, CompareSelectOperation::kEQ),3076Block::make(3077{Store::make(a, {0}, 0),3078For::make(3079x,30800,308110,3082For::make(3083y,30840,308510,3086Block::make({Cond::make(3087CompareSelect::make(y, 3, CompareSelectOperation::kEQ),3088Block::make(3089{Store::make(3090a, {0}, Add::make(Load::make(a, {0}), 1)),3091Store::make(3092b, {0}, Add::make(Load::make(b, {0}), 1))}),3093nullptr)})))}),3094nullptr)});3095
3096/*3097* if (x==2 ? 1 : 0) {
3098* A[0] = 0;
3099* for (int x = 0; x < 10; x++) {
3100* for (int y = 0; y < 10; y++) {
3101* if (y==3 ? 1 : 0) {
3102* A[0] = (A[0]) + 1;
3103* B[0] = (B[0]) + 1;
3104* }
3105* }
3106* }
3107* }
3108*/
3109
3110stmt = registerize(stmt);3111
3112/*3113* if (x==2 ? 1 : 0) {
3114* int A_1 = 0;
3115* for (int x = 0; x < 10; x++) {
3116* for (int y = 0; y < 10; y++) {
3117* if (y==3 ? 1 : 0) {
3118* A_1 = A_1 + 1;
3119* B[0] = (B[0]) + 1;
3120* }
3121* }
3122* }
3123* A[0] = A_1;
3124* }
3125*/
3126
3127std::ostringstream oss;3128oss << *stmt;3129
3130const std::string& verification_pattern =3131R"IR(3132# CHECK: if (x==2
3133# CHECK: int A_1 = 0;
3134# CHECK: for (int x
3135# CHECK: for (int y
3136# CHECK: if (y==3
3137# CHECK: A_1 = A_1 + 1;
3138# CHECK: B[0] = (B[0]) + 1;
3139# CHECK: }
3140# CHECK: }
3141# CHECK: }
3142# CHECK: A[0] = A_1;
3143# CHECK: })IR";3144
3145torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3146}
3147
3148// Accesses are registerized inside two conditions, but the immediate parent is
3149// not a condition.
3150TEST(Registerizer, RegisterizerTwoConditionalLoops) {3151BufHandle a("A", {1}, kInt);3152VarHandle x("x", kInt);3153StmtPtr stmt = Block::make(3154{Cond::make(3155CompareSelect::make(x, 5, CompareSelectOperation::kLT),3156For::make(3157x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),3158nullptr),3159Cond::make(3160CompareSelect::make(x, 5, CompareSelectOperation::kGT),3161For::make(3162x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),3163nullptr)});3164
3165/*3166* if (x<5 ? 1 : 0) {
3167* for (int x = 0; x < 10; x++) {
3168* A[0] = (A[0]) + 1;
3169* }
3170* }
3171* if (x>5 ? 1 : 0) {
3172* for (int x = 0; x < 10; x++) {
3173* A[0] = (A[0]) + 1;
3174* }
3175* }
3176*/
3177
3178stmt = registerize(stmt);3179
3180/*3181* if (x<5 ? 1 : 0) {
3182* int A_1 = A[0];
3183* for (int x = 0; x < 10; x++) {
3184* A_1 = A_1 + 1;
3185* }
3186* A[0] = A_1;
3187* }
3188* if (x>5 ? 1 : 0) {
3189* int A_2 = A[0];
3190* for (int x = 0; x < 10; x++) {
3191* A_2 = A_2 + 1;
3192* }
3193* A[0] = A_2;
3194* }
3195*/
3196
3197std::ostringstream oss;3198oss << *stmt;3199
3200const std::string& verification_pattern =3201R"IR(3202# CHECK: if (x<5
3203# CHECK: int A_1 = A[0];
3204# CHECK: for (int x
3205# CHECK: A_1 = A_1 + 1;
3206# CHECK: }
3207# CHECK: A[0] = A_1;
3208# CHECK: }
3209# CHECK: if (x>5
3210# CHECK: int A_2 = A[0];
3211# CHECK: for (int x
3212# CHECK: A_2 = A_2 + 1;
3213# CHECK: }
3214# CHECK: A[0] = A_2;
3215# CHECK: })IR";3216
3217torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3218}
3219
3220// Accesses are registerized inside two conditions, cut in the middle.
3221TEST(Registerizer, RegisterizerTwoConditionalLoopsCut) {3222BufHandle a("A", {1}, kInt);3223VarHandle x("x", kInt);3224StmtPtr stmt = Block::make(3225{Cond::make(3226CompareSelect::make(x, 5, CompareSelectOperation::kLT),3227For::make(3228x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),3229nullptr),3230For::make(x, 0, 10, Store::make(a, {x}, 1)),3231Cond::make(3232CompareSelect::make(x, 5, CompareSelectOperation::kGT),3233For::make(3234x, 0, 10, Store::make(a, {0}, Add::make(Load::make(a, {0}), 1))),3235nullptr)});3236
3237/*3238* if (x<5 ? 1 : 0) {
3239* for (int x = 0; x < 10; x++) {
3240* A[0] = (A[0]) + 1;
3241* }
3242* }
3243* for (int x = 0; x < 10; x++) {
3244* A[x] = 1;
3245* }
3246* if (x>5 ? 1 : 0) {
3247* for (int x = 0; x < 10; x++) {
3248* A[0] = (A[0]) + 1;
3249* }
3250* }
3251*/
3252
3253stmt = registerize(stmt);3254
3255/*3256* if (x<5 ? 1 : 0) {
3257* int A_1 = A[0];
3258* for (int x = 0; x < 10; x++) {
3259* A_1 = A_1 + 1;
3260* }
3261* A[0] = A_1;
3262* }
3263* for (int x = 0; x < 10; x++) {
3264* A[x] = 1;
3265* }
3266* if (x>5 ? 1 : 0) {
3267* int A_2 = A[0];
3268* for (int x = 0; x < 10; x++) {
3269* A_2 = A_2 + 1;
3270* }
3271* A[0] = A_2;
3272* }
3273*/
3274
3275std::ostringstream oss;3276oss << *stmt;3277
3278const std::string& verification_pattern =3279R"IR(3280# CHECK: if (x<5
3281# CHECK: int A_1 = A[0];
3282# CHECK: for (int x
3283# CHECK: A_1 = A_1 + 1;
3284# CHECK: }
3285# CHECK: A[0] = A_1;
3286# CHECK: }
3287# CHECK: for (int x
3288# CHECK: A[x] = 1;
3289# CHECK: if (x>5
3290# CHECK: int A_2 = A[0];
3291# CHECK: for (int x
3292# CHECK: A_2 = A_2 + 1;
3293# CHECK: }
3294# CHECK: A[0] = A_2;
3295# CHECK: })IR";3296
3297torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3298}
3299
3300// references a Let var in a local scope which cannot be hoisted out of the
3301// loop.
3302TEST(Registerizer, RegisterizerLoopLetVar) {3303BufHandle a("A", {10}, kInt);3304VarHandle x("x", kInt);3305VarHandle y("y", kInt);3306StmtPtr stmt = IRSimplifier::simplify(Block::make({For::make(3307x,33080,330910,3310Block::make(3311{Let::make(y, 30),3312Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))}));3313
3314/*3315* for (int x = 0; x < 10; x++) {
3316* int y = 30;
3317* A[y] = x + (A[y]);
3318* }
3319*/
3320
3321std::ostringstream before;3322before << *stmt;3323
3324// No change.3325stmt = registerize(stmt);3326
3327std::ostringstream after;3328after << *stmt;3329
3330ASSERT_EQ(before.str(), after.str());3331}
3332
3333// references a Let var in an outer scope that does not prevent hoisting the
3334// initializer.
3335TEST(Registerizer, RegisterizerLoopLetVarOuter) {3336BufHandle a("A", {10}, kInt);3337VarHandle x("x", kInt);3338VarHandle y("y", kInt);3339StmtPtr stmt = Block::make(3340{Let::make(y, 30),3341For::make(3342x,33430,334410,3345Block::make(3346{Store::make(a, {y}, Add::make(x, Load::make(a, {y})))}))});3347
3348/*3349* int y = 30;
3350* for (int x = 0; x < 10; x++) {
3351* A[y] = x + (A[y]);
3352* }
3353*/
3354
3355stmt = registerize(stmt);3356
3357/*3358* int y = 30;
3359* int A_1 = A[y];
3360* for (int x = 0; x < 10; x++) {
3361* A_1 = A_1 + x;
3362* }
3363* A[y] = A_1;
3364*/
3365
3366std::ostringstream oss;3367oss << *stmt;3368
3369const std::string& verification_pattern =3370R"IR(3371# CHECK: int y = 30;
3372# CHECK: int A_1 = A[y];
3373# CHECK: for (int x
3374# CHECK: A_1 = A_1 + x;
3375# CHECK: A[y] = A_1;)IR";3376
3377torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3378}
3379
3380// Okay so the registerizer generally goes after index flattening, but just in
3381// case. Test multi index registerization.
3382TEST(Registerizer, RegisterizerMultiDim) {3383BufHandle a("A", {3, 4, 5}, kInt);3384VarHandle x("x", kInt);3385StmtPtr stmt = Block::make(3386{Store::make(a, {0, 1, 2}, 0),3387For::make(3388x,33890,339010,3391Block::make({Store::make(3392a, {0, 1, 2}, Add::make(Load::make(a, {0, 1, 2}), x))}))});3393
3394/*3395* A[0, 1, 2] = 0;
3396* for (int x = 0; x < 10; x++) {
3397* A[0, 1, 2] = (A[0, 1, 2]) + x;
3398* }
3399*/
3400
3401stmt = registerize(stmt);3402
3403/*3404* int A_1 = 0;
3405* for (int x = 0; x < 10; x++) {
3406* A_1 = x + A_1;
3407* }
3408* A[0, 1, 2] = A_1;
3409*/
3410
3411std::ostringstream oss;3412oss << *stmt;3413
3414const std::string& verification_pattern =3415R"IR(3416# CHECK: int A_1 = 0;
3417# CHECK: for (int x = 0; x < 10; x++)
3418# CHECK-NOT: A[
3419# CHECK: A_1 =
3420# CHECK: A[0, 1, 2] = A_1;)IR";3421
3422torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3423}
3424
3425// Wont registerize if only some dims match, but will still registerize distinct
3426// elements.
3427TEST(Registerizer, RegisterizerMultiDimPartial) {3428BufHandle a("A", {3, 4, 5}, kInt);3429VarHandle x("x", kInt);3430StmtPtr stmt = Block::make(3431{Store::make(a, {0, 1, 2}, 0),3432For::make(3433x,34340,343510,3436Block::make({Store::make(3437a, {0, 2, 2}, Add::make(Load::make(a, {0, 1, 4}), x))}))});3438
3439/*3440* A[0, 1, 2] = 0;
3441* for (int x = 0; x < 10; x++) {
3442* A[0, 2, 2] = (A[0, 1, 4]) + x;
3443* }
3444*/
3445
3446stmt = registerize(stmt);3447
3448/*3449* A[0, 1, 2] = 0;
3450* int A_1 = A[0, 1, 4];
3451* int A_2 = A[0, 2, 2];
3452* for (int x = 0; x < 10; x++) {
3453* A_2 = A_1 + x;
3454* }
3455* A[0, 2, 2] = A_2;
3456*/
3457
3458std::ostringstream oss;3459oss << *stmt;3460
3461const std::string& verification_pattern =3462R"IR(3463# CHECK: A[0, 1, 2] = 0;
3464# CHECK: int A_1 = A[0, 1, 4];
3465# CHECK: int A_2 = A[0, 2, 2];
3466# CHECK: for (
3467# CHECK: A_2 = A_1 + x;
3468# CHECK: A[0, 2, 2] = A_2;)IR";3469
3470torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3471}
3472
3473// If they could overlap across all dimensions we cannot registerize.
3474TEST(Registerizer, RegisterizerMultiDimOverlap) {3475BufHandle a("A", {3, 4, 5}, kInt);3476VarHandle x("x", kInt);3477VarHandle y("y", kInt);3478StmtPtr stmt = Block::make(3479{Store::make(a, {0, 1, 2}, 0),3480For::make(3481x,34820,348310,3484Block::make({Store::make(3485a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 2}), x))}))});3486stmt = IRSimplifier::simplify(stmt);3487
3488/*3489* A[0, 1, 2] = 0;
3490* for (int x = 0; x < 10; x++) {
3491* A[0, x, 2] = (A[y, 2, 2]) + x;
3492* }
3493*/
3494
3495std::ostringstream before;3496before << *stmt;3497
3498// No change.3499stmt = registerize(stmt);3500
3501std::ostringstream after;3502after << *stmt;3503
3504ASSERT_EQ(before.str(), after.str());3505}
3506
3507// But, if one dimension is known to be distinct they do not overlap.
3508TEST(Registerizer, RegisterizerMultiDimPartialOverlap) {3509BufHandle a("A", {3, 4, 5}, kInt);3510VarHandle x("x", kInt);3511VarHandle y("y", kInt);3512StmtPtr stmt = Block::make(3513{Store::make(a, {0, 1, 2}, 0),3514For::make(3515x,35160,351710,3518Block::make({Store::make(3519a, {0, x, 2}, Add::make(Load::make(a, {y, 2, 4}), x))}))});3520
3521/*3522* A[0, 1, 2] = 0; <---- 2nd dim overlaps with store.
3523* for (int x = 0; x < 10; x++) {
3524* A[0, x, 2] = (A[y, 2, 4]) + x; <---- 3rd dim has constant diff.
3525* }
3526*/
3527
3528stmt = registerize(stmt);3529
3530/*3531* A[0, 1, 2] = 0;
3532* int A_1 = A[y, 2, 4];
3533* for (int x = 0; x < 10; x++) {
3534* A[0, x, 2] = A_1 + x;
3535* }
3536*/
3537
3538std::ostringstream oss;3539oss << *stmt;3540
3541const std::string& verification_pattern =3542R"IR(3543# CHECK: A[0, 1, 2] = 0;
3544# CHECK: int A_1 = A[y, 2, 4];
3545# CHECK: for (
3546# CHECK: A[0, x, 2] = A_1 + x;
3547# CHECK: })IR";3548
3549torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3550}
3551
3552// A 3D reduction with different input dimensionality.
3553TEST(Registerizer, RegisterizerMultiDim3DReduction1) {3554BufHandle a("A", {10}, kInt);3555BufHandle b("B", {10, 10}, kInt);3556BufHandle c("C", {10, 10, 10}, kInt);3557VarHandle x("x", kInt);3558VarHandle y("y", kInt);3559VarHandle z("z", kInt);3560StmtPtr stmt = For::make(3561x,35620,356310,3564For::make(3565y,35660,356710,3568For::make(3569z,35700,357110,3572Store::make(3573c,3574{x, y, z},3575Add::make(3576Load::make(c, {x, y, z}),3577Mul::make(Load::make(b, {x, y}), Load::make(a, {x})))))));3578
3579/*3580* for (int x = 0; x < 10; x++) {
3581* for (int y = 0; y < 10; y++) {
3582* for (int z = 0; z < 10; z++) {
3583* C[x, y, z] = (C[x, y, z]) + (B[x, y]) * (A[x]);
3584* }
3585* }
3586* }
3587*/
3588
3589// We can registerize the A and B access since they can be hoisted before3590// hitting a dependent loop var.3591
3592stmt = registerize(stmt);3593
3594/*3595* for (int x = 0; x < 10; x++) {
3596* int A_1 = A[x];
3597* for (int y = 0; y < 10; y++) {
3598* int B_1 = B[x, y];
3599* for (int z = 0; z < 10; z++) {
3600* C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3601* }
3602* }
3603* }
3604*/
3605
3606std::ostringstream oss;3607oss << *stmt;3608
3609const std::string& verification_pattern =3610R"IR(3611# CHECK: for (int x
3612# CHECK: int A_1 = A[x];
3613# CHECK: for (int y
3614# CHECK: int B_1 = B[x, y];
3615# CHECK: for (int z
3616# CHECK: C[x, y, z] = A_1 * B_1 + (C[x, y, z]);
3617# CHECK: })IR";3618
3619torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3620}
3621
3622// A 3D reduction with the same smaller dimensionality using different loop
3623// vars.
3624TEST(Registerizer, RegisterizerMultiDim3DReduction2) {3625BufHandle a("A", {10}, kInt);3626BufHandle b("B", {10}, kInt);3627BufHandle c("C", {10}, kInt);3628VarHandle x("x", kInt);3629VarHandle y("y", kInt);3630VarHandle z("z", kInt);3631StmtPtr stmt = For::make(3632x,36330,363410,3635// NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)3636For::make(3637y,36380,363910,3640For::make(3641z,36420,364310,3644Store::make(3645c,3646{x},3647Add::make(3648Load::make(c, {x}),3649Mul::make(Load::make(b, {y}), Load::make(a, {x})))))));3650
3651/*3652* for (int x = 0; x < 10; x++) {
3653* for (int y = 0; y < 10; y++) {
3654* for (int z = 0; z < 10; z++) {
3655* C[x] = (C[x]) + (B[y]) * (A[x]);
3656* }
3657* }
3658* }
3659*/
3660
3661// We can registerize all accesses, the A and C access can be hoisted to the3662// outer loop since they depend only on it's loop var while the B can only be3663// raised to the loop of y.3664
3665stmt = registerize(stmt);3666
3667/*3668* for (int x = 0; x < 10; x++) {
3669* int A_1 = A[x];
3670* int C_1 = C[x];
3671* for (int y = 0; y < 10; y++) {
3672* int B_1 = B[y];
3673* for (int z = 0; z < 10; z++) {
3674* C_1 = A_1 * B_1 + C_1;
3675* }
3676* }
3677* C[x] = C_1;
3678* }
3679*/
3680
3681std::ostringstream oss;3682oss << *stmt;3683
3684const std::string& verification_pattern =3685R"IR(3686# CHECK: for (int x
3687# CHECK: int A_1 = A[x];
3688# CHECK: int C_1 = C[x];
3689# CHECK: for (int y
3690# CHECK: int B_1 = B[y];
3691# CHECK: for (int z
3692# CHECK: C_1 = A_1 * B_1 + C_1;
3693# CHECK: }
3694# CHECK: }
3695# CHECK: C[x] = C_1;
3696# CHECK: })IR";3697
3698torch::jit::testing::FileCheck().run(verification_pattern, oss.str());3699}
3700
3701} // namespace jit3702} // namespace torch3703