// PyTorch TensorExpr reduction tests (test/cpp/tensorexpr/test_reductions.cpp).
1#include <gtest/gtest.h>
2
3#include <limits>
4#include <memory>
5#include <sstream>
6#include <stdexcept>
7#include <unordered_map>
8
9#include <test/cpp/tensorexpr/test_base.h>
10
11#include <c10/util/irange.h>
12#include <test/cpp/tensorexpr/padded_buffer.h>
13#include <torch/csrc/jit/tensorexpr/analysis.h>
14#include <torch/csrc/jit/tensorexpr/eval.h>
15#include <torch/csrc/jit/tensorexpr/ir.h>
16#include <torch/csrc/jit/tensorexpr/ir_printer.h>
17#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
18#include <torch/csrc/jit/tensorexpr/loopnest.h>
19#include <torch/csrc/jit/tensorexpr/tensor.h>
20#include <torch/csrc/jit/testing/file_check.h>
21
22namespace torch {
23namespace jit {
24
25using namespace torch::jit::tensorexpr;
26
27TEST(Reductions, ReduceSum0D_1) {
28const int M = 10;
29
30BufHandle b("b", {M}, kFloat);
31std::vector<float> in(M);
32for (const auto j : c10::irange(M)) {
33in[j] = j;
34}
35
36std::vector<float> out(M, -1.f);
37
38Tensor c = Reduce("sum", {M}, Sum(), b, {});
39LoopNest loop({c});
40loop.prepareForCodegen();
41StmtPtr s = loop.root_stmt();
42s = IRSimplifier::simplify(s);
43
44SimpleIREvaluator cg(s, {b, c});
45
46cg.call({in, out});
47for (const auto i : c10::irange(M)) {
48ASSERT_EQ(out[i], in[i]);
49}
50}
51
// Reduce a 0-dim (scalar) buffer with no reduction axes: the output is an
// identity copy of the single element.
TEST(Reductions, ReduceSum0D_2) {
  BufHandle b("b", {}, kFloat);
  std::vector<float> in(1);
  // Note: double literal narrowed to float on assignment.
  in[0] = 77.7;

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  ASSERT_EQ(out[0], in[0]);
}
70
71// Sum an array to a single value.
72TEST(Reductions, ReduceSum1D) {
73BufHandle b("b", {10}, kFloat);
74std::vector<float> in(10);
75for (const auto j : c10::irange(10)) {
76in[j] = j;
77}
78
79std::vector<float> out(1, -1.f);
80
81Tensor c = Reduce("sum", {}, Sum(), b, {10});
82LoopNest loop({c});
83loop.prepareForCodegen();
84StmtPtr s = loop.root_stmt();
85s = IRSimplifier::simplify(s);
86
87SimpleIREvaluator cg(s, {b, c});
88
89cg.call({in, out});
90ASSERT_EQ(out[0], 45);
91}
// Sum a 2D tensor to a 1D tensor with dynamic shapes.
TEST(Reductions, ReduceSum2D) {
  const int M = 3;
  const int N = 7;

  // Dynamic dims, bound at call time.
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = j;
    }
  }

  std::vector<float> out(M, -1.f);

  Tensor c = Reduce("sum", {M}, Sum(), b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, n, m});

  // NOTE(review): n=5 and m=7 are bound here, while the data was laid out
  // for M=3 rows of N=7 and the reduce extents were fixed statically above.
  // Presumably only the static dims drive the generated loops — confirm
  // the dynamic-dim values are intentionally different.
  cg.call({in, out, 5, 7});

  // Each row holds 0..N-1, so each output is sum(0..6) == 21.
  float expected = 0;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}
130
// Sum a 3D tensor to both a 2D and 1D tensor, then reduce the 2D tensor flat to
// check our work.
TEST(Reductions, ReduceSum3D) {
  const int M = 10;
  // Innermost dim is dynamic; M is bound at call time.
  VarHandle m("m", kInt);

  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m});

  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> cData(2 * 3, 6.0f);
  std::vector<float> dData(2, 1.0f);
  std::vector<float> eData(2, 1.0f);

  // Each of the 6 rows holds the values 0..M-1.
  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({bData, cData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }

  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(cData[i], expected);
  }

  // Reduce two axes at once: {3, m} collapses to a 1-D result of size 2.
  Tensor d = Reduce("sum2", {2}, Sum(), b, {3, m});
  LoopNest loop2({d});
  loop2.prepareForCodegen();
  StmtPtr s2 = loop2.root_stmt();
  s2 = IRSimplifier::simplify(s2);

  SimpleIREvaluator cg2(s2, {b, d, m});
  cg2.call({bData, dData, M});

  // We're combining an additional dimension of 3, so the sum is 3x.
  expected = expected * 3;

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected);
  }

  // This is the same as just reducing the original result across that axis.
  BufHandle c_buf(c.buf());
  Tensor e = Reduce("sum3", {2}, Sum(), c_buf, {3});
  LoopNest loop3({e});
  loop3.prepareForCodegen();
  StmtPtr s3 = loop3.root_stmt();
  s3 = IRSimplifier::simplify(s3);

  SimpleIREvaluator cg3(s3, {c, e});
  cg3.call({cData, eData});

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(eData[i], expected);
  }
}
200
// Sum a large (10 D) Tensor 5 dimensions in.
TEST(Reductions, ReduceSum10D) {
  BufHandle in_("in_", {2, 3, 2, 3, 2, 3, 2, 3, 2, 3}, kFloat);
  const int InputSize = 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3 * 2 * 3;
  BufHandle out_("out_", {2, 3, 2, 3, 2}, kFloat);
  const int OutputSize = 2 * 3 * 2 * 3 * 2;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(OutputSize, -1.f);

  // Keep the first five dims, reduce over the last five.
  Tensor c = Reduce("sum", {2, 3, 2, 3, 2}, Sum(), in_, {3, 2, 3, 2, 3});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, c});

  cg.call({in, out});

  // All inputs are 1.f, so each output equals the reduced element count.
  // NOLINTNEXTLINE(bugprone-integer-division)
  float expected = InputSize / OutputSize;
  for (const auto i : c10::irange(OutputSize)) {
    ASSERT_EQ(out[i], expected);
  }
}
227
// Reduce via Mul rather than Add using a custom Reducer.
TEST(Reductions, ReduceProduct) {
  const int M = 4;
  const int N = 4;

  BufHandle b("b", {M, N}, kFloat);
  std::vector<float> in(M * N);
  // Every row holds 2, 3, 4, 5.
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      in[i * N + j] = 2 + j;
    }
  }

  std::vector<float> out(M, -1.f);

  // Identity element 1.f; combine with multiplication.
  Reducer product(
      ExprHandle(1.f), [](ExprHandle a, ExprHandle b) { return a * b; });

  Tensor c = Reduce("product", {M}, product, b, {N});
  LoopNest loop({c});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});

  // Expected per-row product: 2 * 3 * 4 * 5 == 120.
  float expected = 1;
  for (const auto i : c10::irange(N)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected *= 2 + i;
  }

  for (const auto i : c10::irange(M)) {
    ASSERT_EQ(out[i], expected);
  }
}
266
// Maximum reductions.
TEST(Reductions, ReduceMax) {
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  for (const auto j : c10::irange(10)) {
    in[j] = j;
  }

  // Full reduction of the single axis to a scalar max.
  Tensor dm1 = Reduce("max", {}, Maximum(kFloat), in_, {10});

  LoopNest loop({dm1});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);
  SimpleIREvaluator cg(s, {in_, dm1});

  cg.call({in, out});

  ASSERT_EQ(out[0], 9);

  // Same 10 values, reinterpreted as a 2x5 buffer; reduce per row.
  BufHandle in2_("b", {2, 5}, kFloat);
  std::vector<float> out2(2, -1.f);

  Tensor m2d = Reduce("max", {2}, Maximum(kFloat), in2_, {5});

  LoopNest loop2({m2d});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {in2_, m2d});
  // Reuses `in` (0..9): row 0 = {0..4}, row 1 = {5..9}.
  cg2.call({in, out2});

  ASSERT_EQ(out2[0], 4);
  ASSERT_EQ(out2[1], 9);
}
305
// Minimum reduction, with custom initialization.
TEST(Reductions, ReduceMinCustomInitializer) {
  // The reduction's initial value is itself a runtime parameter.
  VarHandle minInit("minInit", kFloat);
  BufHandle in_("b", {10}, kFloat);

  std::vector<float> in(10);
  std::vector<float> out(1, -1.f);
  // Inputs are 10..19, all above typical "zero" initializers.
  for (const auto j : c10::irange(10)) {
    in[j] = 10 + j;
  }

  Tensor min = Reduce(
      "min",
      {},
      Minimum(ExprHandle(minInit)),
      [&](ParameterList& v) { return in_.load(v); },
      {10});

  LoopNest loop({min});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in_, min, minInit});

  // Works normally (note that out data starts lower than the correct
  // minimum).
  cg.call({in, out, std::numeric_limits<float>::max()});
  ASSERT_EQ(out[0], 10);

  // With an initializer lower than the min, that's the min.
  cg.call({in, out, 5.f});
  ASSERT_EQ(out[0], 5);
}
340
// Example implementation of Any/All.
// TODO: this is very awkward without logical And/Or operators.
TEST(Reductions, ReduceAnyAll) {
  VarHandle searchValue("searchValue", kInt);
  BufHandle b("b", {4, 10}, kInt);

  // "Any": once the accumulator is 1 it stays 1, else take the new value.
  Reducer anyEqSV(ExprHandle(0), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 1, 1, b, kEQ);
  });

  Tensor any = Reduce(
      "anyEqual",
      {4},
      anyEqSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kEQ);
      },
      {10});

  LoopNest loop({any});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, any, searchValue});

  std::vector<int> in(40, 0);
  std::vector<int> out(4, 0);

  // input has 0-39 in 4 rows: row r holds 10*r .. 10*r+9.
  for (const auto i : c10::irange(40)) {
    in[i] = i;
  }
  cg.call({in, out, 1});

  // only the first row has 1
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  cg.call({in, out, 15});

  // 15 is in the second row (index 1, values 10..19)
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 0);
  ASSERT_EQ(out[3], 0);

  // "All": once the accumulator is 0 it stays 0, else take the new value.
  Reducer allGTSV(ExprHandle(1), [](ExprHandle a, ExprHandle b) {
    return CompareSelect::make(a, 0, 0, b, kEQ);
  });

  Tensor allGreaterThan = Reduce(
      "allGreaterThan",
      {4},
      allGTSV,
      [&](const auto& i, const auto& j) {
        return CompareSelect::make(b.load(i, j), searchValue, kGT);
      },
      {10});

  LoopNest loop2({allGreaterThan});
  loop2.prepareForCodegen();
  s = loop2.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg2(s, {b, allGreaterThan, searchValue});

  cg2.call({in, out, 11});

  // Only rows 2 and 3 (values 20..39) are entirely greater than 11; rows 0
  // and 1 contain values <= 11.
  ASSERT_EQ(out[0], 0);
  ASSERT_EQ(out[1], 0);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);

  cg2.call({in, out, -3});

  // All are positive.
  ASSERT_EQ(out[0], 1);
  ASSERT_EQ(out[1], 1);
  ASSERT_EQ(out[2], 1);
  ASSERT_EQ(out[3], 1);
}
426
// Express a 3x2 * 2x3 matmul as a sum-reduction over the shared K axis.
TEST(Reductions, ReduceMatmul2D) {
  BufHandle tA("tA", {3, 2}, kFloat);
  BufHandle tB("tB", {2, 3}, kFloat);

  std::vector<float> tA_(6);
  std::vector<float> tB_(6);

  std::vector<float> out(9, -1.f);
  // tB is filled as the transpose of tA.
  for (const auto i : c10::irange(3)) {
    for (const auto j : c10::irange(2)) {
      tA_[i * 2 + j] = i * 2 + j;
      tB_[j * 3 + i] = i * 2 + j;
    }
  }

  // mm[m, n] = sum_k tA[m, k] * tB[k, n]
  Tensor mm = Reduce(
      "mm",
      {3, 3},
      Sum(),
      [&](const ExprHandle& m, const ExprHandle& n, const ExprHandle& k) {
        return tA.load(m, k) * tB.load(k, n);
      },
      {2});

  LoopNest loop({mm});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {tA, tB, mm});
  cg.call({tA_, tB_, out});

  std::vector<float> expected(
      {1.f, 3.f, 5.f, 3.f, 13.f, 23.f, 5.f, 23.f, 41.f});

  for (const auto i : c10::irange(9)) {
    ASSERT_EQ(out[i], expected[i]);
  }
}
466
// Emulate an rfactor by chaining two reductions: a per-row partial sum (l1)
// followed by a reduction of the partials (l2).
TEST(Reductions, ReduceRfactorLike) {
  BufHandle in("in", {10, 10}, kFloat);
  std::vector<float> in_(100);
  for (const auto i : c10::irange(100)) {
    in_[i] = i;
  }
  std::vector<float> in_rf_(10, -2.f);
  std::vector<float> out(1, -1.f);

  Tensor l1 = Reduce("l1", {10}, Sum(), in, {10});
  // l2 reads l1's output buffer directly.
  BufHandle in_rf(l1.buf());

  Tensor l2 = Reduce("l2", {}, Sum(), in_rf, {10});

  LoopNest loop({l1, l2});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, l1, l2});
  cg.call({in_, in_rf_, out});

  // sum(0..99) == 99 * 100 / 2 == 99 * 50.
  ASSERT_EQ(out[0], 99 * 50);
}
491
// A reduction whose output feeds a subsequent elementwise Compute.
TEST(Reductions, ReduceAsProducer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Reduce("sum", {2, 3}, Sum(), b, {m});
  // d scales the reduction result elementwise by a.
  Tensor d =
      Compute("scale", {2, 3}, [&](const VarHandle& l, const VarHandle& n) {
        return c.load(l, n) * a.load(l, n);
      });
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2 * 3, 6.0f);

  // a[i] = 6 - i; every row of b holds 0..M-1.
  for (int i = 0; i < 2 * 3; ++i) {
    aData[i] = 6 - i;
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j;
    }
  }

  cg.call({aData, bData, dData, M});
  float expected = 0;
  for (const auto i : c10::irange(M)) {
    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
    expected += i;
  }
  for (int i = 0; i < 2 * 3; ++i) {
    ASSERT_EQ(dData[i], expected * (6 - i));
  }
}
532
// A reduction that consumes the output of an elementwise Compute.
TEST(Reductions, ReduceAsConsumer) {
  const int M = 10;
  VarHandle m("m", kInt);

  BufHandle a("a", {2, 3, m}, kFloat);
  BufHandle b("b", {2, 3, m}, kFloat);

  Tensor c = Compute(
      "scale",
      {2, 3, m},
      // Note: the lambda parameter `m` shadows the outer VarHandle m.
      [&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
        return b.load(l, n, m) * a.load(l, n, m);
      });
  // Reduce the product over the last two axes.
  Tensor d = Reduce("sum", {2}, Sum(), c, {3, m});
  LoopNest loop({d}, {c, d});
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {a, b, d, m});

  std::vector<float> aData(2 * 3 * M, 0);
  std::vector<float> bData(2 * 3 * M, 0);
  std::vector<float> dData(2, 6.0f);

  for (int i = 0; i < 2 * 3; ++i) {
    for (const auto j : c10::irange(M)) {
      bData[i * M + j] = j + 1;
      aData[i * M + j] = 6 - i;
    }
  }

  cg.call({aData, bData, dData, M});
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  float expected[2] = {0, 0};
  for (const auto i : c10::irange(2)) {
    for (const auto j : c10::irange(3)) {
      for (const auto k : c10::irange(M)) {
        // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
        expected[i] += (k + 1) * (6 - (i * 3 + j));
      }
    }
  }

  for (const auto i : c10::irange(2)) {
    ASSERT_EQ(dData[i], expected[i]);
  }
}
581
// Splitting the reduction axis (loops[1]) must not change the result.
TEST(Reductions, SplitReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);

  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  // loops[1] is the reduce axis of extent 8; split it with a tail.
  LoopNest::splitWithTail(loops[1], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  // Row i is filled with the value i, so its sum is i * 8.
  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}
610
// Splitting the non-reduce (output) axis twice must not change the result.
TEST(Reductions, SplitNonReduceAxis) {
  BufHandle in("in", {16, 8}, kFloat);

  std::vector<float> in_(16 * 8);
  for (const auto i : c10::irange(16)) {
    for (const auto j : c10::irange(8)) {
      in_[i * 8 + j] = i;
    }
  }
  std::vector<float> out(16, -1.f);
  Tensor tensor = Reduce("sum", {16}, Sum(), in, {8});
  LoopNest l({tensor});
  std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
  // Split the outer axis, then split the resulting outer loop again.
  LoopNest::splitWithTail(loops[0], 2);
  LoopNest::splitWithTail(loops[0], 2);

  l.prepareForCodegen();

  StmtPtr s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {in, tensor});
  cg.call({in_, out});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out[i], i * 8);
  }
}
639
// Reordering axes across the reduction initializer must preserve results:
// compare an untransformed schedule against one with GPU axis bindings and a
// reorderAxis applied.
TEST(Reductions, ReorderedReductionInitializer) {
  /* From the quip:
  for k in 0..1:  // blockIdx
    for m in 0..128:
      for n in 0..64:  // threadIdx
        SumOp(c(k, n), 0, a(k, m, n), {m})
  */

  BufHandle in("in", {1, 12, 6}, kFloat);
  std::vector<float> in_(12 * 6, 1.f);

  // Reference: plain schedule, no transforms.
  Tensor tensor_ = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l_({tensor_});

  l_.prepareForCodegen();
  StmtPtr s_ = Stmt::clone(l_.root_stmt());
  s_ = IRSimplifier::simplify(s_);

  // Same reduction, with axis bindings and a reorder applied.
  Tensor tensor = Reduce("sum", {1, 12}, Sum(), in, {6});
  LoopNest l({tensor});

  auto loops = l.getLoopStmtsFor(tensor);
  loops[0]->set_gpu_block_index(0);
  loops[1]->set_gpu_thread_index(0);

  LoopNest::reorderAxis(loops[1], loops[2]);

  StmtPtr s = l.root_stmt();
  // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
  s = IRSimplifier::simplify(s);

  l.prepareForCodegen();

  s = l.root_stmt();
  s = IRSimplifier::simplify(s);

  // NOTE(review): only 12 outputs are produced but the buffers and the
  // comparison loop use 16 slots; indices 12..15 just compare the untouched
  // -1.f fill values — presumably intentional slack, verify.
  std::vector<float> out1(16, -1.f);
  SimpleIREvaluator cg(s_, {in, tensor_});
  cg.call({in_, out1});

  std::vector<float> out2(16, -1.f);
  SimpleIREvaluator cg2(s, {in, tensor});
  cg2.call({in_, out2});

  for (const auto i : c10::irange(16)) {
    ASSERT_EQ(out1[i], out2[i]);
  }
}
688
// rfactor a 2D full reduction over its outer axis; this should introduce a
// second ReduceOp (the partial-sum buffer) without changing the result.
TEST(Reductions, ReduceRfactor) {
  const int M = 10;
  const int N = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);

  BufHandle b("b", {m, n}, kFloat);
  std::vector<float> in(M * N);
  for (int j = 0; j < M * N; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  // Writes to c.buf() are [initializer, reduce body]; rfactor the body
  // across the outermost reduce loop.
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  // rfactor produces two ReduceOps: the partial and the final reduction.
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n});

  cg.call({in, out, M, N});
  // sum(0..99) == 4950.
  ASSERT_EQ(out[0], 4950);
}
719
// rfactor on the innermost reduction axis is rejected; the nest must stay a
// single ReduceOp and still compute the correct sum.
TEST(Reductions, Reduce3DRfactorInner) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  // Innermost axis: rfactor must refuse.
  ASSERT_FALSE(loop.rfactor(c_body, loops.at(2)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 1);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});

  cg.call({in, out, M, N, K});
  // sum(0..999) == 499500.
  ASSERT_EQ(out[0], 499500);
}
752
// rfactor a 3D full reduction over its outermost axis; this should succeed
// and add a second ReduceOp, leaving the result unchanged.
TEST(Reductions, Reduce3DRfactorOuter) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  VarHandle m("m", kInt);
  VarHandle n("n", kInt);
  VarHandle k("k", kInt);

  BufHandle b("b", {m, n, k}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, loops.at(0)));
  auto rc = NodeFinder<ReduceOp>::find(loop.root_stmt());
  ASSERT_EQ(rc.size(), 2);
  loop.prepareForCodegen();
  StmtPtr s = loop.root_stmt();
  s = IRSimplifier::simplify(s);

  SimpleIREvaluator cg(s, {b, c, m, n, k});
  cg.call({in, out, M, N, K});
  // sum(0..999) == 499500.
  ASSERT_EQ(out[0], 499500);
}
784
// Apply rfactor repeatedly (1 to 4 times, each on the next loop of the
// intermediate buffer) and check results match an untransformed reference.
TEST(Reductions, ReduceRepeatedInternalRfactor) {
  BufHandle in_("in_", {2, 3, 4, 5, 6}, kFloat);
  const int InputSize = 2 * 3 * 4 * 5 * 6;

  std::vector<float> in(InputSize, 1.f);
  std::vector<float> out(1, -1.f);
  std::vector<float> ref(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), in_, {2, 3, 4, 5, 6});
  LoopNest orig_loop({c});

  // Try rfactoring N outer loops
  for (const auto rfac_number : c10::irange(1, 5)) {
    LoopNest refloop(orig_loop);
    LoopNest loop(orig_loop);
    refloop.prepareForCodegen();
    SimpleIREvaluator ref_cg(
        IRSimplifier::simplify(refloop.root_stmt()), {in_, c});
    ref_cg.call({in, ref});

    // Each rfactor writes into a fresh temp buffer; chain through it.
    BufPtr tmp_buf = c.buf();

    for (const auto idx : c10::irange(rfac_number)) {
      auto reduce = loop.getAllWritesToBuf(tmp_buf)[1];
      ASSERT_TRUE(loop.rfactor(
          reduce, loop.getLoopStmtsFor(tmp_buf).at(idx), &tmp_buf));
    }

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {in_, c});
    cg.call({in, out});

    ASSERT_EQ(ref[0], out[0]);
  }
}
823
// Split a reduction axis with a tail loop.
TEST(Reductions, ReduceSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // Try the split on each of the three loops in turn (axis 0 is the output
  // axis; 1 and 2 are reduce axes). 10 % 8 != 0, so a tail loop is created.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    // Row 0 sums 0..99 == 4950.
    ASSERT_EQ(out[0], 4950);
  }
}
854
// Split a reduction axis cleanly so there is no tail loop.
TEST(Reductions, ReduceSplitNoTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // 10 % 5 == 0, so the split divides evenly with no tail.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}
884
// Split a reduction axis with only a tail loop (the split loop will be size 0
// and eliminated out).
TEST(Reductions, ReduceOverSplitTail) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // Split factor 16 > extent 10: everything lands in the tail loop.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithTail(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}
916
// Split a reduction axis with a mask.
TEST(Reductions, ReduceSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // splitWithMask guards the out-of-range iterations with a condition
  // instead of generating a tail loop.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 8);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}
947
// Split a reduction axis cleanly not requiring a mask.
TEST(Reductions, ReduceSplitNoMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;
  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // 10 % 5 == 0, so no masking condition is needed.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 5);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}
977
// Split a reduction axis with all logic in the mask.
TEST(Reductions, ReduceOverSplitMask) {
  const int M = 10;
  const int N = 10;
  const int K = 10;

  BufHandle b("b", {M, N, K}, kFloat);
  std::vector<float> in(M * N * K);
  for (int j = 0; j < M * N * K; ++j) {
    in[j] = j;
  }

  // Split factor 16 > extent 10: the mask condition bounds every iteration.
  for (const auto i : c10::irange(3)) {
    std::vector<float> out(M, -1.f);

    Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
    LoopNest loop({c});
    std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
    LoopNest::splitWithMask(loops[i], 16);

    loop.prepareForCodegen();
    StmtPtr s = loop.root_stmt();
    s = IRSimplifier::simplify(s);

    SimpleIREvaluator cg(s, {b, c});

    cg.call({in, out});
    ASSERT_EQ(out[0], 4950);
  }
}
1008
1009// Test an rfactor when there are two ReduceOps in the graph due to a
1010// splitWithTail.
1011TEST(Reductions, ReduceSplitRfactor) {
1012const int M = 2;
1013const int N = 10;
1014const int K = 10;
1015const int SPLIT_FACTOR = 4;
1016
1017BufHandle b("b", {M, N, K}, kFloat);
1018std::vector<float> in(M * N * K);
1019for (const auto m : c10::irange(M)) {
1020for (int j = 0; j < N * K; ++j) {
1021in[m * N * K + j] = j;
1022}
1023}
1024
1025std::vector<float> out(M, -1.f);
1026
1027Tensor c = Reduce("sum", {M}, Sum(), b, {N, K});
1028LoopNest loop({c});
1029std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
1030LoopNest::splitWithTail(loops[2], SPLIT_FACTOR);
1031
1032auto c_body = loop.getAllWritesToBuf(c.buf())[2];
1033auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
1034ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
1035LoopNest::reorderAxis(all_loops[2][1], all_loops[2][2]);
1036all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
1037ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(2).size() == 3);
1038ASSERT_TRUE(loop.rfactor(c_body, all_loops[2][1]));
1039loop.prepareForCodegen();
1040loop.simplify();
1041StmtPtr s = loop.root_stmt();
1042
1043SimpleIREvaluator cg(s, {b, c});
1044
1045cg.call({in, out});
1046for (const auto i : c10::irange(M)) {
1047(void)i; // Suppress unused variable warning
1048ASSERT_EQ(out[0], 4950);
1049}
1050}
1051
// Test an rfactor which ends up being eliminated since the total loop size is
// smaller than the split factor.
TEST(Reductions, ReduceOverSplitRfactor) {
  const int N = 10;
  const int K = 10;
  const int SPLIT_FACTOR = 16;

  BufHandle b("b", {N, K}, kFloat);
  std::vector<float> in(N * K);
  for (int j = 0; j < N * K; ++j) {
    in[j] = j;
  }

  std::vector<float> out(1, -1.f);

  Tensor c = Reduce("sum", {}, Sum(), b, {N, K});
  LoopNest loop({c});
  std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  ForPtr i, t;
  // Over-split the inner reduce axis (16 > 10): everything moves to the
  // tail loop `t`, while the main split loop `i` becomes empty.
  LoopNest::splitWithTail(loops[1], SPLIT_FACTOR, &i, &t);
  LoopNest::reorderAxis(loops[0], i);

  auto all_loops = loop.getAllLoopNestsWritingToBuf(c.buf());
  ASSERT_TRUE(all_loops.size() == 3 && all_loops.at(1).size() == 3);
  auto c_body = loop.getAllWritesToBuf(c.buf())[1];
  ASSERT_TRUE(loop.rfactor(c_body, all_loops[1][0]));
  LoopNest::reorderAxis(all_loops[1][0], all_loops[1][2]);

  loop.prepareForCodegen();
  loop.simplify();
  StmtPtr s = loop.root_stmt();

  SimpleIREvaluator cg(s, {b, c});

  cg.call({in, out});
  // sum(0..99) == 4950.
  ASSERT_EQ(out[0], 4950);

  std::ostringstream oss;
  oss << *cg.stmt();

  // Check the IR to verify the rfactored reduce is eliminated.
  // TODO: The alloc free should be eliminated here since it is size 0.
  /*
  const std::string& verification_pattern =
      R"IR(
# CHECK: Allocate(tmp_buf); // dtype=float, dims=[0]
# CHECK: sum[0] = 0.f;
# CHECK: for (int n = 0; n < 10; n++) {
# CHECK:   for (int k_tail = 0; k_tail < 10; k_tail++) {
# CHECK:     sum[0] = (sum[0]) + (b[k_tail + 10 * n]);
# CHECK:   }
# CHECK: }
# CHECK: Free(tmp_buf);)IR";
  */
  // TODO: rfactor output is not consistent yet, will fix (@nickg).
  // torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
1110
// Inlining a buffer produced by a reduction must be rejected; no codegen is
// performed here, only the computeInline refusal is checked.
TEST(Reductions, ReduceInlineReduction) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Reduce("x", {M}, Sum(), b_buf, {N, K});
  Tensor y = Compute(
      "y", {M}, [&](const VarHandle& m) { return a_buf.load(m) + x.load(m); });

  PaddedBuffer<float> a_v(M);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    a_v(i) = i * i;
  }
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        b_v(i, j, k) = j * j * k;
      }
    }
  }

  LoopNest l1({y}, {x, y});
  // Cannot inline a reduction computation
  ASSERT_FALSE(l1.computeInline(x.buf()));
}
1141
// Inlining an elementwise producer into a reduction consumer: results must
// match the non-inlined version, and the inlined IR should be shorter.
TEST(Reductions, ReduceInlineConsumer) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });
  Tensor y = Reduce("y", {M}, Sum(), x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  // l1 keeps x as a separate buffer; l2 inlines x into y.
  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  // The inlined program should print smaller (no separate x loop nest).
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}
1194
// Same as ReduceInlineConsumer, but with a custom reducer whose interaction
// expression is non-trivial (1 + min(a, b)), to exercise inlining inside the
// reducer's body.
TEST(Reductions, ReduceInlineReducerInternal) {
  const int M = 4;
  const int N = 5;
  const int K = 6;

  BufHandle a_buf("a", {M, N, K}, kFloat);
  BufHandle b_buf("b", {M, N, K}, kFloat);

  Tensor x = Compute(
      "x",
      {M, N, K},
      [&](const VarHandle& m, const VarHandle& n, const VarHandle& k) {
        return a_buf.load(m, n, k) + b_buf.load(m, n, k);
      });

  Reducer minimum(ExprHandle(0.f), [&](ExprHandle a, ExprHandle b) {
    return Add::make(ExprHandle(1.f), Min::make(a, b, false));
  });
  Tensor y = Reduce("y", {M}, minimum, x, {N, K});

  PaddedBuffer<float> a_v(M, N, K);
  PaddedBuffer<float> b_v(M, N, K);

  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      for (const auto k : c10::irange(K)) {
        a_v(i, j, k) = i * i + k;
        b_v(i, j, k) = j * j + k;
      }
    }
  }

  // l1 keeps x materialized; l2 inlines it into the reducer.
  LoopNest l1({y}, {x, y});
  LoopNest l2(l1);
  l2.computeInline(x.buf());

  l1.prepareForCodegen();
  l2.prepareForCodegen();

  StmtPtr stmt1 = IRSimplifier::simplify(l1.root_stmt());
  StmtPtr stmt2 = IRSimplifier::simplify(l2.root_stmt());

  SimpleIREvaluator eval1(stmt1, {a_buf, b_buf, y});
  SimpleIREvaluator eval2(stmt2, {a_buf, b_buf, y});

  PaddedBuffer<float> y_1(M);
  PaddedBuffer<float> y_2(M);

  eval1(a_v, b_v, y_1);
  eval2(a_v, b_v, y_2);
  ExpectAllNear(y_1, y_2, 1e-5);
  // The inlined program should print smaller.
  std::ostringstream oss1, oss2;
  oss1 << *stmt1;
  oss2 << *stmt2;
  ASSERT_GT(oss1.str().size(), oss2.str().size());
}
1251
1252TEST(Reductions, ReductionCacheAccessesOperatorAxis) {
1253int L = 4;
1254int N = 3;
1255int M = 2;
1256
1257BufHandle a("a", {L, N, M}, kFloat);
1258BufHandle b("b", {L, N, M}, kFloat);
1259
1260Tensor c = Compute(
1261"scale",
1262{L, N, M},
1263[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
1264return b.load(l, n, m) * a.load(l, n, m);
1265});
1266Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
1267
1268Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
1269return b.load(0, 0, l) * d.load(l);
1270});
1271
1272LoopNest l({e}, {c, d, e});
1273LoopNest l_before(l);
1274l_before.prepareForCodegen();
1275SimpleIREvaluator cg_before(
1276LoopNest::sanitizeNames(l_before.root_stmt()), {a, b, e});
1277
1278StmtPtr d_loop = l.getLoopStmtsFor(d)[0];
1279l.cacheAccesses(d.buf(), "d_local", d_loop);
1280l.prepareForCodegen();
1281
1282StmtPtr result =
1283LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
1284SimpleIREvaluator cg_after(result, {a, b, e});
1285
1286std::ostringstream oss;
1287oss << *cg_after.stmt();
1288const std::string& expected_ir =
1289R"IR(
1290#CHECK: Allocate(d_local); // dtype=float, dims=[4]
1291#CHECK: for (int i_2
1292#CHECK: d_local[i_2] = 0.f
1293#CHECK: for (int
1294#CHECK: for (int
1295#CHECK: d_local[i_2] = (d_local[i_2]) + (scale[
1296#CHECK: }
1297#CHECK: }
1298#CHECK: }
1299#CHECK: for (int i_3
1300#CHECK: sum[i_3] = d_local[i_3]
1301#CHECK: Free(d_local);
1302#CHECK-NOT: d_local
1303)IR";
1304torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1305
1306PaddedBuffer<float> a_v(L, M, N, "a");
1307PaddedBuffer<float> b_v(L, M, N, "b");
1308PaddedBuffer<float> c_v(L, M, N, "c");
1309PaddedBuffer<float> d_v(L, "d");
1310PaddedBuffer<float> e_before(L, "e_before");
1311PaddedBuffer<float> e_after(L, "e_after");
1312
1313for (const auto l : c10::irange(L)) {
1314for (const auto m : c10::irange(M)) {
1315for (const auto n : c10::irange(N)) {
1316a_v(l, m, n) = at::randn({1}).item().to<float>();
1317b_v(l, m, n) = at::randn({1}).item().to<float>();
1318}
1319}
1320}
1321
1322cg_before.call({a_v, b_v, e_before});
1323cg_after.call({a_v, b_v, e_after});
1324
1325// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1326ExpectAllNear(e_before, e_after, 1e-5);
1327}
1328
1329TEST(Reductions, ReductionCacheAccessesOuterReduceAxis) {
1330int L = 4;
1331int N = 3;
1332int M = 2;
1333
1334BufHandle a("a", {L, N, M}, kFloat);
1335BufHandle b("b", {L, N, M}, kFloat);
1336
1337Tensor c = Compute(
1338"scale",
1339{L, N, M},
1340[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
1341return b.load(l, n, m) * a.load(l, n, m);
1342});
1343Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
1344
1345Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
1346return b.load(0, 0, l) * d.load(l);
1347});
1348
1349LoopNest l({e}, {c, d, e});
1350LoopNest l_before(l);
1351l_before.prepareForCodegen();
1352SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
1353
1354StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
1355l.cacheAccesses(d.buf(), "d_local", d_loop);
1356l.prepareForCodegen();
1357
1358StmtPtr result =
1359LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
1360SimpleIREvaluator cg_after(result, {a, b, e});
1361
1362std::ostringstream oss;
1363oss << *cg_after.stmt();
1364const std::string& expected_ir =
1365R"IR(
1366#CHECK: Allocate(d_local); // dtype=float, dims=[1]
1367#CHECK: sum[i_1] = 0
1368#CHECK: d_local[0] = sum[i_1]
1369#CHECK: for (int j_1
1370#CHECK: for (int k_1
1371#CHECK: d_local[0] = (d_local[0]) + (scale[
1372#CHECK: }
1373#CHECK: }
1374#CHECK: sum[i_1] = d_local[0]
1375#CHECK: Free(d_local);
1376#CHECK-NOT: d_local
1377)IR";
1378torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1379
1380PaddedBuffer<float> a_v(L, M, N, "a");
1381PaddedBuffer<float> b_v(L, M, N, "b");
1382PaddedBuffer<float> c_v(L, M, N, "c");
1383PaddedBuffer<float> d_v(L, "d");
1384PaddedBuffer<float> e_before(L, "e_before");
1385PaddedBuffer<float> e_after(L, "e_after");
1386
1387for (const auto l : c10::irange(L)) {
1388for (const auto m : c10::irange(M)) {
1389for (const auto n : c10::irange(N)) {
1390a_v(l, m, n) = at::randn({1}).item().to<float>();
1391b_v(l, m, n) = at::randn({1}).item().to<float>();
1392}
1393}
1394}
1395
1396cg_before.call({a_v, b_v, e_before});
1397cg_after.call({a_v, b_v, e_after});
1398
1399// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1400ExpectAllNear(e_before, e_after, 1e-5);
1401}
1402
1403TEST(Reductions, ReductionCacheAccessesInnerReduceAxis) {
1404int L = 4;
1405int N = 3;
1406int M = 2;
1407
1408BufHandle a("a", {L, N, M}, kFloat);
1409BufHandle b("b", {L, N, M}, kFloat);
1410
1411Tensor c = Compute(
1412"scale",
1413{L, N, M},
1414[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
1415return b.load(l, n, m) * a.load(l, n, m);
1416});
1417Tensor d = Reduce("sum", {L}, Sum(), c, {N, M});
1418
1419Tensor e = Compute("scale", {L}, [&](const VarHandle& l) {
1420return b.load(0, 0, l) * d.load(l);
1421});
1422
1423LoopNest l({e}, {c, d, e});
1424LoopNest l_before(l);
1425l_before.prepareForCodegen();
1426SimpleIREvaluator cg_before(l_before.root_stmt(), {a, b, e});
1427
1428StmtPtr d_loop = l.getLoopStmtsFor(d)[2];
1429l.cacheAccesses(d.buf(), "d_local", d_loop);
1430l.prepareForCodegen();
1431
1432StmtPtr result =
1433LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
1434SimpleIREvaluator cg_after(result, {a, b, e});
1435
1436std::ostringstream oss;
1437oss << *cg_after.stmt();
1438const std::string& expected_ir =
1439R"IR(
1440#CHECK: Allocate(d_local); // dtype=float, dims=[1]
1441#CHECK: sum[i_1] = 0
1442#CHECK: for (int
1443#CHECK: d_local[0] = 0
1444#CHECK: for (int
1445#CHECK: d_local[0] = (d_local[0]) + (scale[
1446#CHECK: }
1447#CHECK: sum[i_1] = (sum[i_1]) + (d_local[0])
1448#CHECK: }
1449#CHECK: Free(d_local);
1450#CHECK-NOT: d_local
1451)IR";
1452torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1453
1454PaddedBuffer<float> a_v(L, M, N, "a");
1455PaddedBuffer<float> b_v(L, M, N, "b");
1456PaddedBuffer<float> c_v(L, M, N, "c");
1457PaddedBuffer<float> d_v(L, "d");
1458PaddedBuffer<float> e_before(L, "e_before");
1459PaddedBuffer<float> e_after(L, "e_after");
1460
1461for (const auto l : c10::irange(L)) {
1462for (const auto m : c10::irange(M)) {
1463for (const auto n : c10::irange(N)) {
1464a_v(l, m, n) = at::randn({1}).item().to<float>();
1465b_v(l, m, n) = at::randn({1}).item().to<float>();
1466}
1467}
1468}
1469
1470cg_before.call({a_v, b_v, e_before});
1471cg_after.call({a_v, b_v, e_after});
1472
1473// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
1474ExpectAllNear(e_before, e_after, 1e-5);
1475}
1476
// Caches accesses to the reduction's *producer* buffer ("scale") at the
// outer reduce axis: a {1, 32, 12} slice is staged into "scale_local"
// before the sum consumes it. Only IR structure is checked here; numeric
// equivalence of cacheAccesses is covered by the ReductionCacheAccesses*
// tests.
TEST(Reductions, ReductionCacheBodyAccess) {
BufHandle a("a", {24, 32, 12}, kFloat);
BufHandle b("b", {24, 32, 12}, kFloat);

// scale[l, n, m] = b[l, n, m] * a[l, n, m]
Tensor c = Compute(
"scale",
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
// sum[l] = reduction of scale over the {32, 12} axes.
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});

LoopNest l({e}, {c, d, e});

// Loop index 1 is the outer reduce axis; cache the producer there.
StmtPtr d_loop = l.getLoopStmtsFor(d)[1];
l.cacheAccesses(c.buf(), "scale_local", d_loop);

l.prepareForCodegen();
StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});

std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(scale_local); // dtype=float, dims=[1, 32, 12]
#CHECK: for (int j_1 = 0; j_1 < 32; j_1++) {
#CHECK: for (int k_1 = 0; k_1 < 12; k_1++) {
#CHECK: scale_local[k_1 + 12 * j_1] = scale[(k_1 + 12 * j_1) + 384 * i_1];
#CHECK: sum[i_1] = (sum[i_1]) + (scale_local[k_2 + 12 * j_2]);
#CHECK: scale_1[i_2] = (b[i_2]) * (sum[i_2]);
#CHECK: Free(scale_local);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
1517
// Caches accesses to the reduction *result* inside its consumer's split
// loop: "sum" is staged into "sum_local" per 4-wide chunk. The cache is
// expected to alias the retired "scale" buffer (Alias in the IR) rather
// than allocate fresh storage. Structure-only check via FileCheck.
TEST(Reductions, ReductionCacheConsumerAccess) {
BufHandle a("a", {24, 32, 12}, kFloat);
BufHandle b("b", {24, 32, 12}, kFloat);

// scale[l, n, m] = b[l, n, m] * a[l, n, m]
Tensor c = Compute(
"scale",
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
// sum[l] = reduction of scale over the {32, 12} axes.
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});

LoopNest l({e}, {c, d, e});

// Split the consumer loop into 24/4 chunks of 4.
LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4);

// Cache the reduction result at the new inner (chunk) loop.
StmtPtr e_loop = l.getLoopStmtsFor(e)[1];
l.cacheAccesses(d.buf(), "sum_local", e_loop);
l.prepareForCodegen();

StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});

std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[i_1] = (sum[i_1]) + (scale[
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
1558
// Same consumer-side cache as ReductionCacheConsumerAccess, but with the
// reduction's output axis also split: the reduction body changes, while
// the consumer-side cache pattern stays the same.
TEST(Reductions, ReductionSplitCacheConsumerAccess) {
BufHandle a("a", {24, 32, 12}, kFloat);
BufHandle b("b", {24, 32, 12}, kFloat);

// scale[l, n, m] = b[l, n, m] * a[l, n, m]
Tensor c = Compute(
"scale",
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
// sum[l] = reduction of scale over the {32, 12} axes.
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});

LoopNest l({e}, {c, d, e});

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ForPtr inner;

// Split outer reduction axis.
LoopNest::splitWithMask(l.getLoopStmtsFor(d)[0], 4, &inner);

// Split reduction consumer.
// Note: `inner` is overwritten here; the cache below targets the
// consumer's inner loop, not the reduction's.
LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

l.cacheAccesses(d.buf(), "sum_local", inner);
l.prepareForCodegen();

StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});

// reduction changes but cache does not.
std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: Alias(sum_local,scale);
#CHECK: sum[j_1 + 4 * i_1] = (sum[j_1 + 4 * i_1]) + (scale[((l + 12 * k_1) + 1536 * i_1) + 384 * j_1]);
#CHECK: for (int i_2 = 0; i_2 < 6
#CHECK: for (int j_2 = 0; j_2 < 4
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_2];
#CHECK: for (int j_3 = 0; j_3 < 4
#CHECK: scale_1[j_3 + 4 * i_2] = (b[j_3 + 4 * i_2]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
1608
// Same consumer-side cache again, but with the reduction's output axis
// reordered against its outer reduce axis: neither the reduction body's
// cache pattern nor the consumer-side cache is disturbed by the reorder.
TEST(Reductions, ReductionReorderCacheConsumerAccess) {
BufHandle a("a", {24, 32, 12}, kFloat);
BufHandle b("b", {24, 32, 12}, kFloat);

// scale[l, n, m] = b[l, n, m] * a[l, n, m]
Tensor c = Compute(
"scale",
{24, 32, 12},
[&](const VarHandle& l, const VarHandle& n, const VarHandle& m) {
return b.load(l, n, m) * a.load(l, n, m);
});
// sum[l] = reduction of scale over the {32, 12} axes.
Tensor d = Reduce("sum", {24}, Sum(), c, {32, 12});

Tensor e = Compute("scale", {24}, [&](const VarHandle& l) {
return b.load(0, 0, l) * d.load(l);
});

LoopNest l({e}, {c, d, e});

// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
ForPtr inner;

// reorder outer reduction axes.
auto loops = l.getLoopStmtsFor(d);
LoopNest::reorderAxis(loops[0], loops[1]);

// Split reduction consumer.
LoopNest::splitWithMask(l.getLoopStmtsFor(e)[0], 4, &inner);

l.cacheAccesses(d.buf(), "sum_local", inner);
l.prepareForCodegen();

StmtPtr result =
LoopNest::sanitizeNames(IRSimplifier::simplify(l.root_stmt()));
SimpleIREvaluator cg(result, {a, b, e});

// neither the reduction body nor the cache changes.
std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: sum[j_1] = (sum[j_1]) + (scale[(k_1 + 12 * i_2) + 384 * j_1]);
#CHECK: for (int i_3 = 0; i_3 < 6;
#CHECK: for (int j_2 = 0; j_2 < 4;
#CHECK: sum_local[j_2] = sum[j_2 + 4 * i_3];
#CHECK: for (int j_3 = 0; j_3 < 4;
#CHECK: scale_1[j_3 + 4 * i_3] = (b[j_3 + 4 * i_3]) * (sum_local[j_3]);
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());
}
1658
// Combines rfactor with cacheAccesses at the *outer* level of the rfactor
// loops: the partial-sum buffer "sum_rfac" gets a full row-sized "tmp"
// cache per outer iteration. Buffer shape is symbolic (m, n, k vars) with
// concrete sizes bound at call time.
TEST(Reductions, ReductionRfactorCacheTempOuter) {
const int M = 10;
const int N = 10;
const int K = 10;
VarHandle m("m", kInt);
VarHandle n("n", kInt);
VarHandle k("k", kInt);

BufHandle b("B", {m, n, k}, kFloat);
std::vector<float> in(M * N * K);
for (int j = 0; j < M * N * K; ++j) {
in[j] = j;
}

std::vector<float> out(1, -1.f);

// Full reduction of B to a scalar.
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});

std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
LoopNest::reorderAxis(loops.at(0), loops.at(1));
loops = loop.getLoopStmtsFor(c);
// Writes to c.buf(): [0] is the init store, [1] is the accumulate.
auto c_body = loop.getAllWritesToBuf(c.buf())[1];
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
BufPtr rfac_buf;
// rfactor over the (now-outer) axis creates per-element partial sums.
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
loop.distributeLoop(loops.at(0));

// Two nests write rfac_buf: its init and the 3-deep accumulation nest.
auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

// Cache the rfactor buffer at the second loop level: "tmp" spans a whole
// row of n partial sums.
all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][1]);
loop.simplify();
loop.prepareForCodegen();
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
SimpleIREvaluator cg(s, {b, c, m, n, k});

std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[n]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK: for (int j = 0; j < n
#CHECK: tmp[j] = 0
#CHECK: }
#CHECK: for (int j_1 = 0; j_1 < n
#CHECK: for (int k
#CHECK: tmp[j_1] = (tmp[j_1]) + (B[
#CHECK: }
#CHECK: }
#CHECK: for (int j_2 = 0; j_2 < n
#CHECK: sum_rfac[j_2] = (sum_rfac[j_2]) + (tmp[j_2]);
#CHECK: }
#CHECK: Free(tmp);
#CHECK-NOT: tmp
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());

// Expected: sum of 0..999 = 999 * 1000 / 2 = 499500.
cg.call({in, out, M, N, K});
ASSERT_EQ(out[0], 499500);
}
1724
// Same rfactor pipeline as ReductionRfactorCacheTempOuter, but the cache
// is placed at the *innermost* rfactor loop, so "tmp" collapses to a
// single scalar slot that is reset, accumulated, and flushed per row.
TEST(Reductions, ReductionRfactorCacheTempInner) {
const int M = 10;
const int N = 10;
const int K = 10;
VarHandle m("m", kInt);
VarHandle n("n", kInt);
VarHandle k("k", kInt);

BufHandle b("B", {m, n, k}, kFloat);
std::vector<float> in(M * N * K);
for (int j = 0; j < M * N * K; ++j) {
in[j] = j;
}

std::vector<float> out(1, -1.f);

// Full reduction of B to a scalar.
Tensor c = Reduce("sum", {}, Sum(), b, {m, n, k});
LoopNest loop({c});
std::vector<ForPtr> loops = loop.getLoopStmtsFor(c);
// Writes to c.buf(): [0] is the init store, [1] is the accumulate.
auto c_body = loop.getAllWritesToBuf(c.buf())[1];

LoopNest::reorderAxis(loops.at(0), loops.at(1));
loops = loop.getLoopStmtsFor(c);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
BufPtr rfac_buf;
// rfactor over the (now-outer) axis creates per-element partial sums.
ASSERT_TRUE(loop.rfactor(c_body, loops.at(0), &rfac_buf));
loop.distributeLoop(loops.at(0));
auto all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
LoopNest::reorderAxis(all_loops[1][0], all_loops[1][1]);

// Cache the rfactor buffer at the innermost loop: "tmp" is one scalar.
all_loops = loop.getAllLoopNestsWritingToBuf(rfac_buf);
ASSERT_TRUE(all_loops.size() == 2 && all_loops.at(1).size() == 3);
LoopNest::cacheAccesses(rfac_buf, "tmp", all_loops[1][2]);
loop.prepareForCodegen();
loop.simplify();
StmtPtr s = LoopNest::sanitizeNames(loop.root_stmt());
SimpleIREvaluator cg(s, {b, c, m, n, k});

std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
R"IR(
#CHECK: Allocate(sum_rfac); // dtype=float, dims=[n]
#CHECK: Allocate(tmp); // dtype=float, dims=[1]
#CHECK: for (int i_1 = 0; i_1 < m
#CHECK: for (int j = 0; j < n
#CHECK: tmp[0] = 0
#CHECK: for (int k
#CHECK: tmp[0] = (tmp[0]) + (B[
#CHECK: }
#CHECK: sum_rfac[j] = (sum_rfac[j]) + (tmp[0]);
#CHECK: Free(tmp);
#CHECK-NOT: tmp
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());

// Expected: sum of 0..999 = 999 * 1000 / 2 = 499500.
cg.call({in, out, M, N, K});
ASSERT_EQ(out[0], 499500);
}
1785
1786TEST(Reductions, ReductionVectorize) {
1787std::vector<float> in_(8 * 8);
1788for (const auto i : c10::irange(8)) {
1789for (const auto j : c10::irange(8)) {
1790in_[i * 8 + j] = i;
1791}
1792}
1793std::vector<float> out_before(8, -1.f);
1794std::vector<float> out_after(8, -1.f);
1795
1796BufHandle in("in", {8, 8}, kFloat);
1797
1798Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
1799LoopNest l_before({tensor});
1800LoopNest l(l_before);
1801l_before.prepareForCodegen();
1802SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
1803cg_before.call({in_, out_before});
1804
1805ASSERT_TRUE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[0]));
1806
1807StmtPtr s = l.root_stmt();
1808s = LoopNest::sanitizeNames(IRSimplifier::simplify(s));
1809
1810std::ostringstream oss;
1811oss << *s;
1812const std::string& expected_ir =
1813R"IR(
1814#CHECK: sum[Ramp(0, 1, 8)] = Broadcast(0.f, 8);
1815#CHECK: for (int i = 0; i < 8; i++) {
1816#CHECK: sum[Ramp(0, 1, 8)] = ReduceOp((sum[Ramp(0, 1, 8)]) + (in[Ramp(i, 8, 8)]), reduce_args={i});
1817#CHECK: }
1818)IR";
1819torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1820
1821// Vectorizing should not change result.
1822l.prepareForCodegen();
1823s = IRSimplifier::simplify(l.root_stmt());
1824SimpleIREvaluator cg_after(s, {in, tensor});
1825cg_after.call({in_, out_after});
1826for (const auto i : c10::irange(8)) {
1827ASSERT_EQ(out_before[i], out_after[i]);
1828}
1829}
1830
1831TEST(Reductions, ReductionVectorizeInner) {
1832BufHandle in("in", {8, 8}, kFloat);
1833
1834Tensor tensor = Reduce("sum", {8}, Sum(), in, {8});
1835LoopNest l({tensor});
1836
1837ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));
1838}
1839
// A reduce axis cannot be vectorized directly, but after rfactor turns it
// into the output axis of a partial-sum buffer, that loop vectorizes.
// Results must match the untransformed nest.
TEST(Reductions, ReductionVectorizeRfactor) {
std::vector<float> in_(8 * 8);
for (const auto i : c10::irange(8)) {
for (const auto j : c10::irange(8)) {
in_[i * 8 + j] = i;
}
}
std::vector<float> out_before(1, -1.f);
std::vector<float> out_after(1, -1.f);

BufHandle in("in", {8, 8}, kFloat);

// Full reduction of the 8x8 input to a scalar.
Tensor tensor = Reduce("sum", {}, Sum(), in, {8, 8});

LoopNest l_before({tensor});
LoopNest l(l_before);
l_before.prepareForCodegen();
SimpleIREvaluator cg_before(l_before.root_stmt(), {in, tensor});
cg_before.call({in_, out_before});

// Direct vectorization of a reduce axis is rejected.
ASSERT_FALSE(LoopNest::vectorize(l.getLoopStmtsFor(tensor)[1]));

// But if we rfactor this so it's not a reduce axis we can vectorize that
// loop.
std::vector<ForPtr> loops = l.getLoopStmtsFor(tensor);
LoopNest::reorderAxis(loops[0], loops[1]);
loops = l.getLoopStmtsFor(tensor);
// Writes to tensor.buf(): [0] is the init store, [1] is the accumulate.
auto tensor_body = l.getAllWritesToBuf(tensor.buf())[1];
BufPtr rfac_buf = nullptr;
ASSERT_TRUE(LoopNest::rfactor(tensor_body, loops.at(0), &rfac_buf));

LoopNest::distributeLoop(loops.at(0));
auto rfac_loops = l.getAllLoopNestsWritingToBuf(rfac_buf);

// The rfactor'd axis is now an output axis and vectorizes cleanly.
ASSERT_TRUE(LoopNest::vectorize(rfac_loops[1][0]));
l.simplify();

StmtPtr s = LoopNest::sanitizeNames(l.root_stmt());

std::ostringstream oss;
oss << *s;
const std::string& expected_ir =
R"IR(
#CHECK: sum = 0.f;
#CHECK: for (int i = 0; i < 8; i++) {
#CHECK: sum_rfac[i] = 0.f;
#CHECK: }
#CHECK: for (int i_1 = 0; i_1 < 8; i_1++) {
#CHECK: sum_rfac[Ramp(0, 1, 8)] = ReduceOp((sum_rfac[Ramp(0, 1, 8)]) + (in[Ramp(8 * i_1, 1, 8)]), reduce_args={i_1});
#CHECK: }
#CHECK: for (int i_2 = 0; i_2 < 8; i_2++) {
#CHECK: sum = ReduceOp((sum) + (sum_rfac[i_2]), reduce_args={i_2});
#CHECK: }
)IR";
torch::jit::testing::FileCheck().run(expected_ir, oss.str());

// Vectorizing should not change result.
l.prepareForCodegen();
s = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg_after(s, {in, tensor});
cg_after.call({in_, out_after});

ASSERT_EQ(out_before[0], out_after[0]);
}
1904
1905TEST(Reductions, InitFunction) {
1906constexpr int M = 32;
1907constexpr int N = 16;
1908BufHandle A("A", {M, N}, kFloat);
1909BufHandle B("B", {N}, kFloat);
1910Tensor C = Reduce(
1911"C",
1912{N},
1913Sum(),
1914[&](const std::vector<VarHandle>& v) { return B.load(v[0]); },
1915[&](const std::vector<VarHandle>& v) { return A.load(v[1], v[0]); },
1916{M});
1917LoopNest nest({C});
1918nest.prepareForCodegen();
1919StmtPtr s = LoopNest::sanitizeNames(IRSimplifier::simplify(nest.root_stmt()));
1920std::ostringstream oss;
1921oss << *s << "\n";
1922const std::string& expected_ir =
1923R"IR(
1924#CHECK: for (int i = 0; i < 16; i++) {
1925#CHECK: C[i] = B[i];
1926#CHECK: for (int j = 0; j < 32; j++) {
1927#CHECK: C[i] = (C[i]) + (A[i + 16 * j]);
1928#CHECK: }
1929#CHECK: }
1930)IR";
1931torch::jit::testing::FileCheck().run(expected_ir, oss.str());
1932}
1933} // namespace jit
1934} // namespace torch
1935